diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 740f12e7d13d..4cd2c7a62f46 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -62,7 +62,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP
val stageId = stage.stageId
val attemptId = stage.attemptId
val name = stage.name
- val status = stage.status.toString
+ val status = stage.status.toString.toLowerCase(Locale.ROOT)
val submissionTime = stage.submissionTime.get.getTime()
val completionTime = stage.completionTime.map(_.getTime())
.getOrElse(System.currentTimeMillis())
@@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP
val stages = jobData.stageIds.map { stageId =>
// This could be empty if the listener hasn't received information about the
// stage or if the stage information has been garbage collected
- store.stageData(stageId).lastOption.getOrElse {
+ store.asOption(store.lastStageAttempt(stageId)).getOrElse {
new v1.StageData(
v1.StageStatus.PENDING,
stageId,
@@ -336,8 +336,14 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP
content ++= makeTimeline(activeStages ++ completedStages ++ failedStages,
store.executorList(false), appStartTime)
- content ++= UIUtils.showDagVizForJob(
- jobId, store.operationGraphForJob(jobId))
+ val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match {
+ case Some(operationGraph) => UIUtils.showDagVizForJob(jobId, operationGraph)
+ case None =>
+ <div id="no-info">
+ <p>No DAG visualization information to display for job {jobId}</p>
+ </div>
+ }
+ content ++= operationGraphContent
if (shouldShowActiveStages) {
content ++= <h4 id="active">Active Stages ({activeStages.size})</h4> ++
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index 99eab1b2a27d..ff1b75e5c506 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore)
val killEnabled = parent.killEnabled
def isFairScheduler: Boolean = {
- store.environmentInfo().sparkProperties.toMap
- .get("spark.scheduler.mode")
- .map { mode => mode == SchedulingMode.FAIR }
- .getOrElse(false)
+ store
+ .environmentInfo()
+ .sparkProperties
+ .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString))
}
def getSparkUser: String = parent.getSparkUser
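
For context, the rewritten isFairScheduler check above is just a membership test on the (key, value) pairs exposed by environmentInfo(), instead of building a Map and comparing the looked-up value. A minimal standalone sketch of the same idea, using a hard-coded Seq[(String, String)] in place of the status store (the property list is a made-up example):

    object SchedulerModeCheck {
      // Stand-in for store.environmentInfo().sparkProperties: a sequence of (key, value) pairs.
      val sparkProperties: Seq[(String, String)] = Seq(
        "spark.app.name" -> "demo",
        "spark.scheduler.mode" -> "FAIR")

      // Same shape as the patched JobsTab/StagesTab code: Seq.contains on a tuple.
      def isFairScheduler: Boolean =
        sparkProperties.contains(("spark.scheduler.mode", "FAIR"))

      def main(args: Array[String]): Unit =
        println(isFairScheduler) // prints: true
    }
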
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 11a6a3434497..73e78aac955d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -19,25 +19,23 @@ package org.apache.spark.ui.jobs
import java.net.URLEncoder
import java.util.Date
+import java.util.concurrent.TimeUnit
import javax.servlet.http.HttpServletRequest
import scala.collection.mutable.{HashMap, HashSet}
-import scala.xml.{Elem, Node, Unparsed}
+import scala.xml.{Node, Unparsed}
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.SparkConf
-import org.apache.spark.internal.config._
import org.apache.spark.scheduler.TaskLocality
-import org.apache.spark.status.AppStatusStore
+import org.apache.spark.status._
import org.apache.spark.status.api.v1._
import org.apache.spark.ui._
-import org.apache.spark.util.{Distribution, Utils}
+import org.apache.spark.util.Utils
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") {
import ApiHelper._
- import StagePage._
private val TIMELINE_LEGEND = {
@@ -67,17 +65,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
// if we find that it's okay.
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
- private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = {
- val localities = taskList.map(_.taskLocality)
- val localityCounts = localities.groupBy(identity).mapValues(_.size)
+ private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = {
val names = Map(
TaskLocality.PROCESS_LOCAL.toString() -> "Process local",
TaskLocality.NODE_LOCAL.toString() -> "Node local",
TaskLocality.RACK_LOCAL.toString() -> "Rack local",
TaskLocality.ANY.toString() -> "Any")
- val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) =>
- s"${names(locality)}: $count"
- }
+ val localityNamesAndCounts = names.flatMap { case (key, name) =>
+ localitySummary.get(key).map { count =>
+ s"$name: $count"
+ }
+ }.toSeq
localityNamesAndCounts.sorted.mkString("; ")
}
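
The locality summary above is now rendered from a precomputed Map[String, Long] of locality level to task count rather than by grouping the full task list. A tiny standalone sketch of just the formatting step, with made-up counts:

    object LocalitySummarySketch {
      def summarize(localitySummary: Map[String, Long]): String = {
        val names = Map(
          "PROCESS_LOCAL" -> "Process local",
          "NODE_LOCAL" -> "Node local",
          "RACK_LOCAL" -> "Rack local",
          "ANY" -> "Any")
        val namesAndCounts = names.flatMap { case (key, name) =>
          localitySummary.get(key).map(count => s"$name: $count")
        }.toSeq
        namesAndCounts.sorted.mkString("; ")
      }

      def main(args: Array[String]): Unit =
        // Example counts; prints "Any: 3; Process local: 12"
        println(summarize(Map("PROCESS_LOCAL" -> 12L, "ANY" -> 3L)))
    }
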
@@ -108,7 +106,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)"
val stageData = parent.store
- .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true))
+ .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false))
.getOrElse {
val content =
@@ -117,8 +115,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
return UIUtils.headerSparkPage(stageHeader, content, parent)
}
- val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq
- if (tasks.isEmpty) {
+ val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId)
+
+ val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks +
+ stageData.numFailedTasks + stageData.numKilledTasks
+ if (totalTasks == 0) {
val content =
<h4>Summary Metrics</h4> No tasks have started yet
@@ -127,18 +128,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
return UIUtils.headerSparkPage(stageHeader, content, parent)
}
+ val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId)
val numCompleted = stageData.numCompleteTasks
- val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks +
- stageData.numFailedTasks + stageData.numKilledTasks
- val totalTasksNumStr = if (totalTasks == tasks.size) {
+ val totalTasksNumStr = if (totalTasks == storedTasks) {
s"$totalTasks"
} else {
- s"$totalTasks, showing ${tasks.size}"
+ s"$totalTasks, showing $storedTasks"
}
- val externalAccumulables = stageData.accumulatorUpdates
- val hasAccumulators = externalAccumulables.size > 0
-
val summary =
@@ -148,7 +145,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
-
Locality Level Summary:
- {getLocalitySummaryString(stageData, tasks)}
+ {getLocalitySummaryString(localitySummary)}
{if (hasInput(stageData)) {
-
@@ -261,12 +258,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
def accumulableRow(acc: AccumulableInfo): Seq[Node] = {
- <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
+ if (acc.name != null && acc.value != null) {
+ <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
+ } else {
+ Nil
+ }
}
val accumulableTable = UIUtils.listingTable(
accumulableHeaders,
accumulableRow,
- externalAccumulables.toSeq)
+ stageData.accumulatorUpdates.toSeq)
val page: Int = {
// If the user has changed to a larger page size, then go to page 1 in order to avoid
@@ -280,16 +281,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val currentTime = System.currentTimeMillis()
val (taskTable, taskTableHTML) = try {
val _taskTable = new TaskPagedTable(
- parent.conf,
+ stageData,
UIUtils.prependBaseUri(parent.basePath) +
s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
- tasks,
- hasAccumulators,
- hasInput(stageData),
- hasOutput(stageData),
- hasShuffleRead(stageData),
- hasShuffleWrite(stageData),
- hasBytesSpilled(stageData),
currentTime,
pageSize = taskPageSize,
sortColumn = taskSortColumn,
@@ -320,217 +314,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
| }
|});
""".stripMargin
- }
+ }
}
- val taskIdsInPage = if (taskTable == null) Set.empty[Long]
- else taskTable.dataSource.slicedTaskIds
-
- // Excludes tasks which failed and have incomplete metrics
- val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined)
+ val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId,
+ Array(0, 0.25, 0.5, 0.75, 1.0))
- val summaryTable: Option[Seq[Node]] =
- if (validTasks.size == 0) {
- None
- } else {
- def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = {
- Distribution(data).get.getQuantiles()
- }
- def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = {
- getDistributionQuantiles(times).map { millis =>
- {UIUtils.formatDuration(millis.toLong)} |
- }
- }
- def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = {
- getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)} | )
+ val summaryTable = metricsSummary.map { metrics =>
+ def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = {
+ data.map { millis =>
+ <td>{UIUtils.formatDuration(millis.toLong)}</td>
}
+ }
- val deserializationTimes = validTasks.map { task =>
- task.taskMetrics.get.executorDeserializeTime.toDouble
- }
- val deserializationQuantiles =
-
-
- Task Deserialization Time
-
- | +: getFormattedTimeQuantiles(deserializationTimes)
-
- val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble)
- val serviceQuantiles = Duration | +: getFormattedTimeQuantiles(serviceTimes)
-
- val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble)
- val gcQuantiles =
-
- GC Time
-
- | +: getFormattedTimeQuantiles(gcTimes)
-
- val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble)
- val serializationQuantiles =
-
-
- Result Serialization Time
-
- | +: getFormattedTimeQuantiles(serializationTimes)
-
- val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble)
- val gettingResultQuantiles =
-
-
- Getting Result Time
-
- | +:
- getFormattedTimeQuantiles(gettingResultTimes)
-
- val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble)
- val peakExecutionMemoryQuantiles = {
-
-
- Peak Execution Memory
-
- | +: getFormattedSizeQuantiles(peakExecutionMemory)
+ def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = {
+ data.map { size =>
+ <td>{Utils.bytesToString(size.toLong)}</td>
}
+ }
- // The scheduler delay includes the network delay to send the task to the worker
- // machine and to send back the result (but not the time to fetch the task result,
- // if it needed to be fetched from the block manager on the worker).
- val schedulerDelays = validTasks.map { task =>
- getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble
- }
- val schedulerDelayTitle = Scheduler Delay |
- val schedulerDelayQuantiles = schedulerDelayTitle +:
- getFormattedTimeQuantiles(schedulerDelays)
- def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double])
- : Seq[Elem] = {
- val recordDist = getDistributionQuantiles(records).iterator
- getDistributionQuantiles(data).map(d =>
- {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} |
- )
+ def sizeQuantilesWithRecords(
+ data: IndexedSeq[Double],
+ records: IndexedSeq[Double]) : Seq[Node] = {
+ data.zip(records).map { case (d, r) =>
+ {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} |
}
+ }
- val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble)
- val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble)
- val inputQuantiles = Input Size / Records | +:
- getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords)
+ def titleCell(title: String, tooltip: String): Seq[Node] = {
+ <td>
+ <span data-toggle="tooltip" title={tooltip} data-placement="right">
+ {title}
+ </span>
+ </td>
+ }
- val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble)
- val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble)
- val outputQuantiles = Output Size / Records | +:
- getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords)
+ def simpleTitleCell(title: String): Seq[Node] = <td>{title}</td>
- val shuffleReadBlockedTimes = validTasks.map { task =>
- task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble
- }
- val shuffleReadBlockedQuantiles =
-
-
- Shuffle Read Blocked Time
-
- | +:
- getFormattedTimeQuantiles(shuffleReadBlockedTimes)
-
- val shuffleReadTotalSizes = validTasks.map { task =>
- totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble
- }
- val shuffleReadTotalRecords = validTasks.map { task =>
- task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble
- }
- val shuffleReadTotalQuantiles =
-
-
- Shuffle Read Size / Records
-
- | +:
- getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords)
-
- val shuffleReadRemoteSizes = validTasks.map { task =>
- task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble
- }
- val shuffleReadRemoteQuantiles =
-
-
- Shuffle Remote Reads
-
- | +:
- getFormattedSizeQuantiles(shuffleReadRemoteSizes)
-
- val shuffleWriteSizes = validTasks.map { task =>
- task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble
- }
+ val deserializationQuantiles = titleCell("Task Deserialization Time",
+ ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime)
- val shuffleWriteRecords = validTasks.map { task =>
- task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble
- }
+ val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime)
- val shuffleWriteQuantiles = Shuffle Write Size / Records | +:
- getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords)
+ val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime)
- val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble)
- val memoryBytesSpilledQuantiles = Shuffle spill (memory) | +:
- getFormattedSizeQuantiles(memoryBytesSpilledSizes)
+ val serializationQuantiles = titleCell("Result Serialization Time",
+ ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime)
- val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble)
- val diskBytesSpilledQuantiles = Shuffle spill (disk) | +:
- getFormattedSizeQuantiles(diskBytesSpilledSizes)
+ val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++
+ timeQuantiles(metrics.gettingResultTime)
- val listings: Seq[Seq[Node]] = Seq(
- {serviceQuantiles} ,
- {schedulerDelayQuantiles} ,
-
- {deserializationQuantiles}
-
- {gcQuantiles} ,
-
- {serializationQuantiles}
- ,
- {gettingResultQuantiles} ,
-
- {peakExecutionMemoryQuantiles}
- ,
- if (hasInput(stageData)) {inputQuantiles} else Nil,
- if (hasOutput(stageData)) {outputQuantiles} else Nil,
- if (hasShuffleRead(stageData)) {
-
- {shuffleReadBlockedQuantiles}
-
- {shuffleReadTotalQuantiles}
-
- {shuffleReadRemoteQuantiles}
-
- } else {
- Nil
- },
- if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil,
- if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil,
- if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil)
-
- val quantileHeaders = Seq("Metric", "Min", "25th percentile",
- "Median", "75th percentile", "Max")
- // The summary table does not use CSS to stripe rows, which doesn't work with hidden
- // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows).
- Some(UIUtils.listingTable(
- quantileHeaders,
- identity[Seq[Node]],
- listings,
- fixedWidth = true,
- id = Some("task-summary-table"),
- stripeRowsWithCss = false))
+ val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory",
+ ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory)
+
+ // The scheduler delay includes the network delay to send the task to the worker
+ // machine and to send back the result (but not the time to fetch the task result,
+ // if it needed to be fetched from the block manager on the worker).
+ val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++
+ timeQuantiles(metrics.schedulerDelay)
+
+ def inputQuantiles: Seq[Node] = {
+ simpleTitleCell("Input Size / Records") ++
+ sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead)
+ }
+
+ def outputQuantiles: Seq[Node] = {
+ simpleTitleCell("Output Size / Records") ++
+ sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten,
+ metrics.outputMetrics.recordsWritten)
}
+ def shuffleReadBlockedQuantiles: Seq[Node] = {
+ titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++
+ timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime)
+ }
+
+ def shuffleReadTotalQuantiles: Seq[Node] = {
+ titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++
+ sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes,
+ metrics.shuffleReadMetrics.readRecords)
+ }
+
+ def shuffleReadRemoteQuantiles: Seq[Node] = {
+ titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++
+ sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead)
+ }
+
+ def shuffleWriteQuantiles: Seq[Node] = {
+ simpleTitleCell("Shuffle Write Size / Records") ++
+ sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes,
+ metrics.shuffleWriteMetrics.writeRecords)
+ }
+
+ def memoryBytesSpilledQuantiles: Seq[Node] = {
+ simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled)
+ }
+
+ def diskBytesSpilledQuantiles: Seq[Node] = {
+ simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled)
+ }
+
+ val listings: Seq[Seq[Node]] = Seq(
+ <tr>{serviceQuantiles}</tr>,
+ <tr>{schedulerDelayQuantiles}</tr>,
+ <tr class={TaskDetailsClassNames.TASK_DESERIALIZATION_TIME}>
+ {deserializationQuantiles}
+ </tr>
+ <tr>{gcQuantiles}</tr>,
+ <tr class={TaskDetailsClassNames.RESULT_SERIALIZATION_TIME}>
+ {serializationQuantiles}
+ </tr>,
+ <tr class={TaskDetailsClassNames.GETTING_RESULT_TIME}>{gettingResultQuantiles}</tr>,
+ <tr class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+ {peakExecutionMemoryQuantiles}
+ </tr>,
+ if (hasInput(stageData)) <tr>{inputQuantiles}</tr> else Nil,
+ if (hasOutput(stageData)) <tr>{outputQuantiles}</tr> else Nil,
+ if (hasShuffleRead(stageData)) {
+ <tr class={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}>
+ {shuffleReadBlockedQuantiles}
+ </tr>
+ <tr>{shuffleReadTotalQuantiles}</tr>
+ <tr class={TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE}>
+ {shuffleReadRemoteQuantiles}
+ </tr>
+ } else {
+ Nil
+ },
+ if (hasShuffleWrite(stageData)) <tr>{shuffleWriteQuantiles}</tr> else Nil,
+ if (hasBytesSpilled(stageData)) <tr>{memoryBytesSpilledQuantiles}</tr> else Nil,
+ if (hasBytesSpilled(stageData)) <tr>{diskBytesSpilledQuantiles}</tr> else Nil)
+
+ val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile",
+ "Max")
+ // The summary table does not use CSS to stripe rows, which doesn't work with hidden
+ // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows).
+ UIUtils.listingTable(
+ quantileHeaders,
+ identity[Seq[Node]],
+ listings,
+ fixedWidth = true,
+ id = Some("task-summary-table"),
+ stripeRowsWithCss = false)
+ }
+
val executorTable = new ExecutorTable(stageData, parent.store)
val maybeAccumulableTable: Seq[Node] =
- if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+ if (hasAccumulators(stageData)) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
val aggMetrics =
taskIdsInPage.contains(t.taskId) },
+ Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil),
currentTime) ++
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
@@ -593,10 +525,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val serializationTimeProportion = toProportion(serializationTime)
val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L)
val deserializationTimeProportion = toProportion(deserializationTime)
- val gettingResultTime = getGettingResultTime(taskInfo, currentTime)
+ val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo)
val gettingResultTimeProportion = toProportion(gettingResultTime)
- val schedulerDelay =
- metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L)
+ val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo)
val schedulerDelayProportion = toProportion(schedulerDelay)
val executorOverhead = serializationTime + deserializationTime
@@ -708,7 +639,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
{
if (MAX_TIMELINE_TASKS < tasks.size) {
- This stage has more than the maximum number of tasks that can be shown in the
+ This page has more than the maximum number of tasks that can be shown in the
visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks
(of {tasks.size} total) are shown.
@@ -733,402 +664,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
}
-private[ui] object StagePage {
- private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = {
- info.resultFetchStart match {
- case Some(start) =>
- info.duration match {
- case Some(duration) =>
- info.launchTime.getTime() + duration - start.getTime()
-
- case _ =>
- currentTime - start.getTime()
- }
-
- case _ =>
- 0L
- }
- }
-
- private[ui] def getSchedulerDelay(
- info: TaskData,
- metrics: TaskMetrics,
- currentTime: Long): Long = {
- info.duration match {
- case Some(duration) =>
- val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime
- math.max(
- 0,
- duration - metrics.executorRunTime - executorOverhead -
- getGettingResultTime(info, currentTime))
-
- case _ =>
- // The task is still running and the metrics like executorRunTime are not available.
- 0L
- }
- }
-
-}
-
-private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String)
-
-private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String)
-
-private[ui] case class TaskTableRowShuffleReadData(
- shuffleReadBlockedTimeSortable: Long,
- shuffleReadBlockedTimeReadable: String,
- shuffleReadSortable: Long,
- shuffleReadReadable: String,
- shuffleReadRemoteSortable: Long,
- shuffleReadRemoteReadable: String)
-
-private[ui] case class TaskTableRowShuffleWriteData(
- writeTimeSortable: Long,
- writeTimeReadable: String,
- shuffleWriteSortable: Long,
- shuffleWriteReadable: String)
-
-private[ui] case class TaskTableRowBytesSpilledData(
- memoryBytesSpilledSortable: Long,
- memoryBytesSpilledReadable: String,
- diskBytesSpilledSortable: Long,
- diskBytesSpilledReadable: String)
-
-/**
- * Contains all data that needs for sorting and generating HTML. Using this one rather than
- * TaskData to avoid creating duplicate contents during sorting the data.
- */
-private[ui] class TaskTableRowData(
- val index: Int,
- val taskId: Long,
- val attempt: Int,
- val speculative: Boolean,
- val status: String,
- val taskLocality: String,
- val executorId: String,
- val host: String,
- val launchTime: Long,
- val duration: Long,
- val formatDuration: String,
- val schedulerDelay: Long,
- val taskDeserializationTime: Long,
- val gcTime: Long,
- val serializationTime: Long,
- val gettingResultTime: Long,
- val peakExecutionMemoryUsed: Long,
- val accumulators: Option[String], // HTML
- val input: Option[TaskTableRowInputData],
- val output: Option[TaskTableRowOutputData],
- val shuffleRead: Option[TaskTableRowShuffleReadData],
- val shuffleWrite: Option[TaskTableRowShuffleWriteData],
- val bytesSpilled: Option[TaskTableRowBytesSpilledData],
- val error: String,
- val logs: Map[String, String])
-
private[ui] class TaskDataSource(
- tasks: Seq[TaskData],
- hasAccumulators: Boolean,
- hasInput: Boolean,
- hasOutput: Boolean,
- hasShuffleRead: Boolean,
- hasShuffleWrite: Boolean,
- hasBytesSpilled: Boolean,
+ stage: StageData,
currentTime: Long,
pageSize: Int,
sortColumn: String,
desc: Boolean,
- store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) {
- import StagePage._
+ store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) {
+ import ApiHelper._
// Keep an internal cache of executor log maps so that long task lists render faster.
private val executorIdToLogs = new HashMap[String, Map[String, String]]()
- // Convert TaskData to TaskTableRowData which contains the final contents to show in the table
- // so that we can avoid creating duplicate contents during sorting the data
- private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc))
+ private var _tasksToShow: Seq[TaskData] = null
- private var _slicedTaskIds: Set[Long] = _
+ override def dataSize: Int = store.taskCount(stage.stageId, stage.attemptId).toInt
- override def dataSize: Int = data.size
-
- override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = {
- val r = data.slice(from, to)
- _slicedTaskIds = r.map(_.taskId).toSet
- r
- }
-
- def slicedTaskIds: Set[Long] = _slicedTaskIds
-
- private def taskRow(info: TaskData): TaskTableRowData = {
- val metrics = info.taskMetrics
- val duration = info.duration.getOrElse(1L)
- val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("")
- val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
- val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L)
- val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
- val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
- val gettingResultTime = getGettingResultTime(info, currentTime)
-
- val externalAccumulableReadable = info.accumulatorUpdates.map { acc =>
- StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")
+ override def sliceData(from: Int, to: Int): Seq[TaskData] = {
+ if (_tasksToShow == null) {
+ _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from,
+ indexName(sortColumn), !desc)
}
- val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L)
-
- val maybeInput = metrics.map(_.inputMetrics)
- val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
- val inputReadable = maybeInput
- .map(m => s"${Utils.bytesToString(m.bytesRead)}")
- .getOrElse("")
- val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("")
-
- val maybeOutput = metrics.map(_.outputMetrics)
- val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L)
- val outputReadable = maybeOutput
- .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
- .getOrElse("")
- val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("")
-
- val maybeShuffleRead = metrics.map(_.shuffleReadMetrics)
- val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L)
- val shuffleReadBlockedTimeReadable =
- maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("")
-
- val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead)
- val shuffleReadSortable = totalShuffleBytes.getOrElse(0L)
- val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("")
- val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("")
-
- val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead)
- val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L)
- val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("")
-
- val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics)
- val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L)
- val shuffleWriteReadable = maybeShuffleWrite
- .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("")
- val shuffleWriteRecords = maybeShuffleWrite
- .map(_.recordsWritten.toString).getOrElse("")
-
- val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime)
- val writeTimeSortable = maybeWriteTime.getOrElse(0L)
- val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms =>
- if (ms == 0) "" else UIUtils.formatDuration(ms)
- }.getOrElse("")
-
- val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled)
- val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L)
- val memoryBytesSpilledReadable =
- maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
- val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled)
- val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L)
- val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
- val input =
- if (hasInput) {
- Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords"))
- } else {
- None
- }
-
- val output =
- if (hasOutput) {
- Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords"))
- } else {
- None
- }
-
- val shuffleRead =
- if (hasShuffleRead) {
- Some(TaskTableRowShuffleReadData(
- shuffleReadBlockedTimeSortable,
- shuffleReadBlockedTimeReadable,
- shuffleReadSortable,
- s"$shuffleReadReadable / $shuffleReadRecords",
- shuffleReadRemoteSortable,
- shuffleReadRemoteReadable
- ))
- } else {
- None
- }
-
- val shuffleWrite =
- if (hasShuffleWrite) {
- Some(TaskTableRowShuffleWriteData(
- writeTimeSortable,
- writeTimeReadable,
- shuffleWriteSortable,
- s"$shuffleWriteReadable / $shuffleWriteRecords"
- ))
- } else {
- None
- }
-
- val bytesSpilled =
- if (hasBytesSpilled) {
- Some(TaskTableRowBytesSpilledData(
- memoryBytesSpilledSortable,
- memoryBytesSpilledReadable,
- diskBytesSpilledSortable,
- diskBytesSpilledReadable
- ))
- } else {
- None
- }
-
- new TaskTableRowData(
- info.index,
- info.taskId,
- info.attempt,
- info.speculative,
- info.status,
- info.taskLocality.toString,
- info.executorId,
- info.host,
- info.launchTime.getTime(),
- duration,
- formatDuration,
- schedulerDelay,
- taskDeserializationTime,
- gcTime,
- serializationTime,
- gettingResultTime,
- peakExecutionMemoryUsed,
- if (hasAccumulators) Some(externalAccumulableReadable.mkString(" ")) else None,
- input,
- output,
- shuffleRead,
- shuffleWrite,
- bytesSpilled,
- info.errorMessage.getOrElse(""),
- executorLogs(info.executorId))
+ _tasksToShow
}
- private def executorLogs(id: String): Map[String, String] = {
+ def tasks: Seq[TaskData] = _tasksToShow
+
+ def executorLogs(id: String): Map[String, String] = {
executorIdToLogs.getOrElseUpdate(id,
store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty))
}
- /**
- * Return Ordering according to sortColumn and desc
- */
- private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = {
- val ordering: Ordering[TaskTableRowData] = sortColumn match {
- case "Index" => Ordering.by(_.index)
- case "ID" => Ordering.by(_.taskId)
- case "Attempt" => Ordering.by(_.attempt)
- case "Status" => Ordering.by(_.status)
- case "Locality Level" => Ordering.by(_.taskLocality)
- case "Executor ID" => Ordering.by(_.executorId)
- case "Host" => Ordering.by(_.host)
- case "Launch Time" => Ordering.by(_.launchTime)
- case "Duration" => Ordering.by(_.duration)
- case "Scheduler Delay" => Ordering.by(_.schedulerDelay)
- case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime)
- case "GC Time" => Ordering.by(_.gcTime)
- case "Result Serialization Time" => Ordering.by(_.serializationTime)
- case "Getting Result Time" => Ordering.by(_.gettingResultTime)
- case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed)
- case "Accumulators" =>
- if (hasAccumulators) {
- Ordering.by(_.accumulators.get)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Accumulators because of no accumulators")
- }
- case "Input Size / Records" =>
- if (hasInput) {
- Ordering.by(_.input.get.inputSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Input Size / Records because of no inputs")
- }
- case "Output Size / Records" =>
- if (hasOutput) {
- Ordering.by(_.output.get.outputSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Output Size / Records because of no outputs")
- }
- // ShuffleRead
- case "Shuffle Read Blocked Time" =>
- if (hasShuffleRead) {
- Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads")
- }
- case "Shuffle Read Size / Records" =>
- if (hasShuffleRead) {
- Ordering.by(_.shuffleRead.get.shuffleReadSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Read Size / Records because of no shuffle reads")
- }
- case "Shuffle Remote Reads" =>
- if (hasShuffleRead) {
- Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Remote Reads because of no shuffle reads")
- }
- // ShuffleWrite
- case "Write Time" =>
- if (hasShuffleWrite) {
- Ordering.by(_.shuffleWrite.get.writeTimeSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Write Time because of no shuffle writes")
- }
- case "Shuffle Write Size / Records" =>
- if (hasShuffleWrite) {
- Ordering.by(_.shuffleWrite.get.shuffleWriteSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Write Size / Records because of no shuffle writes")
- }
- // BytesSpilled
- case "Shuffle Spill (Memory)" =>
- if (hasBytesSpilled) {
- Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Spill (Memory) because of no spills")
- }
- case "Shuffle Spill (Disk)" =>
- if (hasBytesSpilled) {
- Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable)
- } else {
- throw new IllegalArgumentException(
- "Cannot sort by Shuffle Spill (Disk) because of no spills")
- }
- case "Errors" => Ordering.by(_.error)
- case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
- }
- if (desc) {
- ordering.reverse
- } else {
- ordering
- }
- }
-
}
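
The reworked TaskDataSource above no longer materializes and sorts every task in memory; it reports the total count from the store and fetches only the requested page, caching that single slice. A rough sketch of the pattern, with a fake store function standing in for AppStatusStore.taskList (all names here are illustrative):

    object PagedSourceSketch {
      final case class Task(id: Long, runTimeMs: Long)

      // Stand-in for the store: returns `length` tasks starting at `offset`, already sorted.
      def storeTaskList(offset: Int, length: Int): Seq[Task] =
        (offset until (offset + length)).map(i => Task(i.toLong, runTimeMs = 100L * i))

      final class TaskPageSource(totalTasks: Int) {
        // Cache of the single slice being rendered, mirroring _tasksToShow.
        private var tasksToShow: Seq[Task] = null

        def dataSize: Int = totalTasks

        def slice(from: Int, to: Int): Seq[Task] = {
          if (tasksToShow == null) {
            tasksToShow = storeTaskList(from, to - from)
          }
          tasksToShow
        }
      }

      def main(args: Array[String]): Unit =
        println(new TaskPageSource(totalTasks = 1000).slice(200, 300).size) // prints: 100
    }
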
private[ui] class TaskPagedTable(
- conf: SparkConf,
+ stage: StageData,
basePath: String,
- data: Seq[TaskData],
- hasAccumulators: Boolean,
- hasInput: Boolean,
- hasOutput: Boolean,
- hasShuffleRead: Boolean,
- hasShuffleWrite: Boolean,
- hasBytesSpilled: Boolean,
currentTime: Long,
pageSize: Int,
sortColumn: String,
desc: Boolean,
- store: AppStatusStore) extends PagedTable[TaskTableRowData] {
+ store: AppStatusStore) extends PagedTable[TaskData] {
+
+ import ApiHelper._
override def tableId: String = "task-table"
@@ -1142,13 +720,7 @@ private[ui] class TaskPagedTable(
override def pageNumberFormField: String = "task.page"
override val dataSource: TaskDataSource = new TaskDataSource(
- data,
- hasAccumulators,
- hasInput,
- hasOutput,
- hasShuffleRead,
- hasShuffleWrite,
- hasBytesSpilled,
+ stage,
currentTime,
pageSize,
sortColumn,
@@ -1170,37 +742,39 @@ private[ui] class TaskPagedTable(
}
def headers: Seq[Node] = {
+ import ApiHelper._
+
val taskHeadersAndCssClasses: Seq[(String, String)] =
Seq(
- ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""),
- ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""),
- ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY),
- ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
- ("GC Time", ""),
- ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
- ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME),
- ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++
- {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
- {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
- {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
- {if (hasShuffleRead) {
- Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
- ("Shuffle Read Size / Records", ""),
- ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
+ (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""),
+ (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""),
+ (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY),
+ (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
+ (HEADER_GC_TIME, ""),
+ (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
+ (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME),
+ (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++
+ {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++
+ {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++
+ {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++
+ {if (hasShuffleRead(stage)) {
+ Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
+ (HEADER_SHUFFLE_TOTAL_READS, ""),
+ (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
} else {
Nil
}} ++
- {if (hasShuffleWrite) {
- Seq(("Write Time", ""), ("Shuffle Write Size / Records", ""))
+ {if (hasShuffleWrite(stage)) {
+ Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, ""))
} else {
Nil
}} ++
- {if (hasBytesSpilled) {
- Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
+ {if (hasBytesSpilled(stage)) {
+ Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, ""))
} else {
Nil
}} ++
- Seq(("Errors", ""))
+ Seq((HEADER_ERROR, ""))
if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) {
throw new IllegalArgumentException(s"Unknown column: $sortColumn")
@@ -1237,7 +811,17 @@ private[ui] class TaskPagedTable(
<thead>{headerRow}</thead>
}
- def row(task: TaskTableRowData): Seq[Node] = {
+ def row(task: TaskData): Seq[Node] = {
+ def formatDuration(value: Option[Long], hideZero: Boolean = false): String = {
+ value.map { v =>
+ if (v > 0 || !hideZero) UIUtils.formatDuration(v) else ""
+ }.getOrElse("")
+ }
+
+ def formatBytes(value: Option[Long]): String = {
+ Utils.bytesToString(value.getOrElse(0L))
+ }
+
| {task.index} |
{task.taskId} |
@@ -1249,62 +833,102 @@ private[ui] class TaskPagedTable(
{task.host}
{
- task.logs.map {
+ dataSource.executorLogs(task.executorId).map {
case (logName, logUrl) => <div><a href={logUrl}>{logName}</a></div>
}
}
- {UIUtils.formatDate(new Date(task.launchTime))} |
- {task.formatDuration} |
+ {UIUtils.formatDate(task.launchTime)} |
+ {formatDuration(task.taskMetrics.map(_.executorRunTime))} |
- {UIUtils.formatDuration(task.schedulerDelay)}
+ {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))}
|
- {UIUtils.formatDuration(task.taskDeserializationTime)}
+ {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))}
|
- {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""}
+ {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)}
|
- {UIUtils.formatDuration(task.serializationTime)}
+ {formatDuration(task.taskMetrics.map(_.resultSerializationTime))}
|
- {UIUtils.formatDuration(task.gettingResultTime)}
+ {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))}
|
- {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+ {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))}
|
- {if (task.accumulators.nonEmpty) {
- {Unparsed(task.accumulators.get)} |
+ {if (hasAccumulators(stage)) {
+ {accumulatorsInfo(task)} |
}}
- {if (task.input.nonEmpty) {
- {task.input.get.inputReadable} |
+ {if (hasInput(stage)) {
+ metricInfo(task) { m =>
+ val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead)
+ val records = m.inputMetrics.recordsRead
+ <td>{bytesRead} / {records}</td>
+ }
}}
- {if (task.output.nonEmpty) {
- {task.output.get.outputReadable} |
+ {if (hasOutput(stage)) {
+ metricInfo(task) { m =>
+ val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten)
+ val records = m.outputMetrics.recordsWritten
+ <td>{bytesWritten} / {records}</td>
+ }
}}
- {if (task.shuffleRead.nonEmpty) {
+ {if (hasShuffleRead(stage)) {
- {task.shuffleRead.get.shuffleReadBlockedTimeReadable}
+ {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))}
|
- {task.shuffleRead.get.shuffleReadReadable} |
+ {
+ metricInfo(task) { m =>
+ val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics))
+ val records = m.shuffleReadMetrics.recordsRead
+ Unparsed(s"$bytesRead / $records")
+ }
+ } |
- {task.shuffleRead.get.shuffleReadRemoteReadable}
+ {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))}
|
}}
- {if (task.shuffleWrite.nonEmpty) {
- {task.shuffleWrite.get.writeTimeReadable} |
- {task.shuffleWrite.get.shuffleWriteReadable} |
+ {if (hasShuffleWrite(stage)) {
+ {
+ formatDuration(
+ task.taskMetrics.map { m =>
+ TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime)
+ },
+ hideZero = true)
+ } |
+ {
+ metricInfo(task) { m =>
+ val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten)
+ val records = m.shuffleWriteMetrics.recordsWritten
+ Unparsed(s"$bytesWritten / $records")
+ }
+ } |
}}
- {if (task.bytesSpilled.nonEmpty) {
- {task.bytesSpilled.get.memoryBytesSpilledReadable} |
- {task.bytesSpilled.get.diskBytesSpilledReadable} |
+ {if (hasBytesSpilled(stage)) {
+ {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} |
+ {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} |
}}
- {errorMessageCell(task.error)}
+ {errorMessageCell(task.errorMessage.getOrElse(""))}
}
+ private def accumulatorsInfo(task: TaskData): Seq[Node] = {
+ task.accumulatorUpdates.flatMap { acc =>
+ if (acc.name != null && acc.update.isDefined) {
+ Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++ <br />
+ } else {
+ Nil
+ }
+ }
+ }
+
+ private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = {
+ task.taskMetrics.map(fn).getOrElse(Nil)
+ }
+
private def errorMessageCell(error: String): Seq[Node] = {
val isMultiline = error.indexOf('\n') >= 0
// Display the first line by default
@@ -1331,7 +955,68 @@ private[ui] class TaskPagedTable(
}
}
-private object ApiHelper {
+private[ui] object ApiHelper {
+
+ val HEADER_ID = "ID"
+ val HEADER_TASK_INDEX = "Index"
+ val HEADER_ATTEMPT = "Attempt"
+ val HEADER_STATUS = "Status"
+ val HEADER_LOCALITY = "Locality Level"
+ val HEADER_EXECUTOR = "Executor ID"
+ val HEADER_HOST = "Host"
+ val HEADER_LAUNCH_TIME = "Launch Time"
+ val HEADER_DURATION = "Duration"
+ val HEADER_SCHEDULER_DELAY = "Scheduler Delay"
+ val HEADER_DESER_TIME = "Task Deserialization Time"
+ val HEADER_GC_TIME = "GC Time"
+ val HEADER_SER_TIME = "Result Serialization Time"
+ val HEADER_GETTING_RESULT_TIME = "Getting Result Time"
+ val HEADER_PEAK_MEM = "Peak Execution Memory"
+ val HEADER_ACCUMULATORS = "Accumulators"
+ val HEADER_INPUT_SIZE = "Input Size / Records"
+ val HEADER_OUTPUT_SIZE = "Output Size / Records"
+ val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time"
+ val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records"
+ val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads"
+ val HEADER_SHUFFLE_WRITE_TIME = "Write Time"
+ val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records"
+ val HEADER_MEM_SPILL = "Shuffle Spill (Memory)"
+ val HEADER_DISK_SPILL = "Shuffle Spill (Disk)"
+ val HEADER_ERROR = "Errors"
+
+ private[ui] val COLUMN_TO_INDEX = Map(
+ HEADER_ID -> null.asInstanceOf[String],
+ HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX,
+ HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT,
+ HEADER_STATUS -> TaskIndexNames.STATUS,
+ HEADER_LOCALITY -> TaskIndexNames.LOCALITY,
+ HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR,
+ HEADER_HOST -> TaskIndexNames.HOST,
+ HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME,
+ // SPARK-26109: Duration of task as executorRunTime to make it consistent with the
+ // aggregated tasks summary metrics table and the previous versions of Spark.
+ HEADER_DURATION -> TaskIndexNames.EXEC_RUN_TIME,
+ HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY,
+ HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME,
+ HEADER_GC_TIME -> TaskIndexNames.GC_TIME,
+ HEADER_SER_TIME -> TaskIndexNames.SER_TIME,
+ HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME,
+ HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM,
+ HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS,
+ HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE,
+ HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE,
+ HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME,
+ HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS,
+ HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS,
+ HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME,
+ HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE,
+ HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL,
+ HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL,
+ HEADER_ERROR -> TaskIndexNames.ERROR)
+
+ def hasAccumulators(stageData: StageData): Boolean = {
+ stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null }
+ }
def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0
@@ -1349,4 +1034,16 @@ private object ApiHelper {
metrics.localBytesRead + metrics.remoteBytesRead
}
+ def indexName(sortColumn: String): Option[String] = {
+ COLUMN_TO_INDEX.get(sortColumn) match {
+ case Some(v) => Option(v)
+ case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn")
+ }
+ }
+
+ def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = {
+ val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0))
+ (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name))
+ }
+
}
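
The indexName helper introduced above distinguishes three cases: a column backed by a store index (Some), the ID column, which needs no secondary index (None, via the deliberate null mapping), and an unknown column (exception). A small self-contained sketch of that dispatch, with a toy two-entry map standing in for COLUMN_TO_INDEX:

    object IndexNameSketch {
      // Toy stand-in for ApiHelper.COLUMN_TO_INDEX: "ID" maps to null on purpose,
      // meaning "sort by the natural key, no secondary index needed".
      private val columnToIndex: Map[String, String] = Map(
        "ID" -> null.asInstanceOf[String],
        "Duration" -> "exec-run-time")

      def indexName(sortColumn: String): Option[String] =
        columnToIndex.get(sortColumn) match {
          case Some(v) => Option(v) // Option(null) collapses to None
          case None => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn")
        }

      def main(args: Array[String]): Unit = {
        println(indexName("Duration")) // Some(exec-run-time)
        println(indexName("ID"))       // None
      }
    }
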
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 18a4926f2f6c..f001a01de395 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -43,7 +43,9 @@ private[ui] class StageTableBase(
killEnabled: Boolean,
isFailedStage: Boolean) {
// stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
+ val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
+ UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ }
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))
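
The StageTableBase change above now sanitizes parameter names as well as values before they are echoed back into links. A rough sketch of the transformation, using a placeholder stripXSS where the real UIUtils.stripXSS is assumed (its actual character handling may differ):

    object ParamSanitizeSketch {
      // Placeholder for UIUtils.stripXSS; the real method's filtering rules may differ.
      private def stripXSS(s: String): String = s.replaceAll("[<>\"'%;()&+]", "")

      // Mirrors the patched code: both keys and values pass through the sanitizer.
      def sanitize(params: Map[String, Array[String]]): Map[String, Seq[String]] =
        params.map { case (k, v) => stripXSS(k) -> v.map(stripXSS).toSeq }

      def main(args: Array[String]): Unit = {
        val raw = Map("stage.sort<script>" -> Array("Duration\"onload=alert(1)"))
        println(sanitize(raw)) // keys and values are both sanitized
      }
    }
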
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index be05a963f0e6..10b032084ce4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore)
attachPage(new PoolPage(this))
def isFairScheduler: Boolean = {
- store.environmentInfo().sparkProperties.toMap
- .get("spark.scheduler.mode")
- .map { mode => mode == SchedulingMode.FAIR }
- .getOrElse(false)
+ store
+ .environmentInfo()
+ .sparkProperties
+ .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString))
}
def handleKillRequest(request: HttpServletRequest): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
index 827a8637b9bd..948858224d72 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
@@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging {
// Use a special prefix here to differentiate this cluster from other operation clusters
val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId
val stageClusterName = s"Stage ${stage.stageId}" +
- { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" }
+ { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" }
val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName)
var rootNodeCount = 0
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 02cee7f8c5b3..2674b9291203 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
import org.apache.spark.status.AppStatusStore
-import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo}
+import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo}
import org.apache.spark.ui._
import org.apache.spark.util.Utils
@@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web
rddStorageInfo.partitions.get,
blockPageSize,
blockSortColumn,
- blockSortDesc)
+ blockSortDesc,
+ store.executorList(true))
_blockTable.table(page)
} catch {
case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
@@ -182,7 +183,8 @@ private[ui] class BlockDataSource(
rddPartitions: Seq[RDDPartitionInfo],
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) {
+ desc: Boolean,
+ executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) {
private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc))
@@ -198,7 +200,10 @@ private[ui] class BlockDataSource(
rddPartition.storageLevel,
rddPartition.memoryUsed,
rddPartition.diskUsed,
- rddPartition.executors.mkString(" "))
+ rddPartition.executors
+ .map { id => executorIdToAddress.get(id).getOrElse(id) }
+ .sorted
+ .mkString(" "))
}
/**
@@ -226,7 +231,8 @@ private[ui] class BlockPagedTable(
rddPartitions: Seq[RDDPartitionInfo],
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedTable[BlockTableRowData] {
+ desc: Boolean,
+ executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] {
override def tableId: String = "rdd-storage-by-block-table"
@@ -243,7 +249,8 @@ private[ui] class BlockPagedTable(
rddPartitions,
pageSize,
sortColumn,
- desc)
+ desc,
+ executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap)
override def pageLink(page: Int): String = {
val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
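
The RDDPage changes above make the block table render executor host:port addresses instead of bare executor IDs, falling back to the ID when no executor summary is available. A small sketch of that lookup, with hand-written sample data in place of store.executorList:

    object ExecutorAddressSketch {
      // Stand-in for executorSummaries.map(ex => (ex.id, ex.hostPort)).toMap
      val executorIdToAddress: Map[String, String] = Map(
        "1" -> "worker-a:34567",
        "2" -> "worker-b:34568")

      // Same fallback as BlockDataSource: unknown IDs are shown as-is.
      def render(executors: Seq[String]): String =
        executors.map(id => executorIdToAddress.getOrElse(id, id)).sorted.mkString(" ")

      def main(args: Array[String]): Unit =
        println(render(Seq("2", "1", "driver"))) // driver worker-a:34567 worker-b:34568
    }
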
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index f4a736d6d439..bf618b4afbce 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
private[spark] case class AccumulatorMetadata(
@@ -199,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
}
override def toString: String = {
+ // getClass.getSimpleName can throw a "Malformed class name" error;
+ // call the safer `Utils.getSimpleName` instead.
if (metadata == null) {
- "Un-registered Accumulator: " + getClass.getSimpleName
+ "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
} else {
- getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+ Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
}
}
}
@@ -211,7 +214,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
/**
* An internal class used to track accumulators by Spark itself.
*/
-private[spark] object AccumulatorContext {
+private[spark] object AccumulatorContext extends Logging {
/**
* This global map holds the original accumulator objects that are created on the driver.
@@ -258,13 +261,16 @@ private[spark] object AccumulatorContext {
* Returns the [[AccumulatorV2]] registered with the given ID, if any.
*/
def get(id: Long): Option[AccumulatorV2[_, _]] = {
- Option(originals.get(id)).map { ref =>
- // Since we are storing weak references, we must check whether the underlying data is valid.
+ val ref = originals.get(id)
+ if (ref eq null) {
+ None
+ } else {
+ // Since we are storing weak references, warn when the underlying data is not valid.
val acc = ref.get
if (acc eq null) {
- throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id")
+ logWarning(s"Attempted to access garbage collected accumulator $id")
}
- acc
+ Option(acc)
}
}
@@ -290,7 +296,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private var _count = 0L
/**
- * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * Returns false if this accumulator has had any values added to it or the sum is non-zero.
+ *
* @since 2.0.0
*/
override def isZero: Boolean = _sum == 0L && _count == 0
@@ -368,6 +375,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private var _sum = 0.0
private var _count = 0L
+ /**
+ * Returns false if this accumulator has had any values added to it or the sum is non-zero.
+ */
override def isZero: Boolean = _sum == 0.0 && _count == 0
override def copy(): DoubleAccumulator = {
@@ -441,6 +451,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())
+ /**
+ * Returns false if this accumulator instance has any values in it.
+ */
override def isZero: Boolean = _list.isEmpty
override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator
@@ -479,7 +492,9 @@ class LegacyAccumulatorWrapper[R, T](
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
private[spark] var _value = initialValue // Current value on driver
- override def isZero: Boolean = _value == param.zero(initialValue)
+ @transient private lazy val _zero = param.zero(initialValue)
+
+ override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
override def copy(): LegacyAccumulatorWrapper[R, T] = {
val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -488,7 +503,7 @@ class LegacyAccumulatorWrapper[R, T](
}
override def reset(): Unit = {
- _value = param.zero(initialValue)
+ _value = _zero
}
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
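
The AccumulatorContext.get change above swaps an IllegalStateException for a warning plus None when the weak reference has already been cleared by the garbage collector. A stripped-down sketch of the same pattern over a plain map of WeakReferences (the names are illustrative, not the real internals):

    import java.lang.ref.WeakReference

    import scala.collection.concurrent.TrieMap

    object WeakLookupSketch {
      private val originals = TrieMap.empty[Long, WeakReference[AnyRef]]

      def register(id: Long, value: AnyRef): Unit =
        originals.put(id, new WeakReference[AnyRef](value))

      // Mirrors the patched get(): missing id -> None, cleared reference -> warn and return None.
      def get(id: Long): Option[AnyRef] =
        originals.get(id) match {
          case None => None
          case Some(ref) =>
            val value = ref.get
            if (value eq null) {
              Console.err.println(s"Attempted to access garbage collected accumulator $id")
            }
            Option(value)
        }
    }
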
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index 31d230d0fec8..21acaa95c564 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -22,9 +22,7 @@ package org.apache.spark.util
* through all the elements.
*/
private[spark]
-// scalastyle:off
abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
-// scalastyle:on
private[this] var completed = false
def next(): A = sub.next()
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 5e60218c5740..ff83301d631c 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -263,7 +263,7 @@ private[spark] object JsonProtocol {
val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing)
val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing)
("Stage ID" -> stageInfo.stageId) ~
- ("Stage Attempt ID" -> stageInfo.attemptId) ~
+ ("Stage Attempt ID" -> stageInfo.attemptNumber) ~
("Stage Name" -> stageInfo.name) ~
("Number of Tasks" -> stageInfo.numTasks) ~
("RDD Info" -> rddInfo) ~
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index 76a56298aaeb..4a7798434680 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
}
+ /**
+ * This can be overridden by subclasses if there is any extra cleanup to do when removing a
+ * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
+ */
+ def removeListenerOnError(listener: L): Unit = {
+ removeListener(listener)
+ }
+
+
/**
* Post the event to all registered listeners. The `postToAll` caller should guarantee calling
* `postToAll` in the same thread for all events.
@@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
try {
doPostEvent(listener, event)
+ if (Thread.interrupted()) {
+ // We want to throw the InterruptedException right away so we can associate the interrupt
+ // with this listener, as opposed to waiting for a queue.take() etc. to detect it.
+ throw new InterruptedException()
+ }
} catch {
+ case ie: InterruptedException =>
+ logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " +
+ s"Removing that listener.", ie)
+ removeListenerOnError(listener)
case NonFatal(e) =>
logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
} finally {
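A compact sketch (not the real LiveListenerBus) of the dispatch-and-remove pattern added above: after each post the thread's interrupt flag is checked, so an interrupt can be attributed to the listener that caused it and that listener removed, rather than surfacing later in an unrelated blocking call such as queue.take(). The Listener trait and SimpleBus class are illustrative stand-ins.

import java.util.concurrent.CopyOnWriteArrayList
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

trait Listener[E] { def onEvent(event: E): Unit }

class SimpleBus[E] {
  private val listeners = new CopyOnWriteArrayList[Listener[E]]()

  def addListener(l: Listener[E]): Unit = listeners.add(l)
  def removeListenerOnError(l: Listener[E]): Unit = listeners.remove(l)

  def postToAll(event: E): Unit = {
    for (listener <- listeners.asScala) {
      try {
        listener.onEvent(event)
        if (Thread.interrupted()) {
          // Attribute the interrupt to this listener instead of letting a later
          // blocking call (e.g. queue.take()) discover it.
          throw new InterruptedException()
        }
      } catch {
        case _: InterruptedException =>
          removeListenerOnError(listener)
        case NonFatal(e) =>
          // Log and keep going; one misbehaving listener should not stop delivery to the rest.
          println(s"Listener threw an exception: $e")
      }
    }
  }
}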
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 585330297314..9197c23df709 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,8 @@
package org.apache.spark.util
import java.io._
+import java.lang.{Byte => JByte}
+import java.lang.InternalError
import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
@@ -26,6 +28,7 @@ import java.nio.ByteBuffer
import java.nio.channels.{Channels, FileChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
+import java.security.SecureRandom
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
@@ -44,6 +47,7 @@ import scala.util.matching.Regex
import _root_.io.netty.channel.unix.Errors.NativeIoException
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
import com.google.common.io.{ByteStreams, Files => GFiles}
import com.google.common.net.InetAddresses
import org.apache.commons.lang3.SystemUtils
@@ -1102,7 +1106,7 @@ private[spark] object Utils extends Logging {
}
/**
- * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If
+ * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
* no suffix is provided, the passed number is assumed to be in ms.
*/
def timeStringAsMs(str: String): Long = {
@@ -1872,7 +1876,7 @@ private[spark] object Utils extends Logging {
/** Return the class name of the given object, removing all dollar signs */
def getFormattedClassName(obj: AnyRef): String = {
- obj.getClass.getSimpleName.replace("$", "")
+ getSimpleName(obj.getClass).replace("$", "")
}
/** Return an option that translates JNothing to None */
@@ -2805,6 +2809,71 @@ private[spark] object Utils extends Logging {
s"k8s://$resolvedURL"
}
+
+ def createSecret(conf: SparkConf): String = {
+ val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+ val rnd = new SecureRandom()
+ val secretBytes = new Array[Byte](bits / JByte.SIZE)
+ rnd.nextBytes(secretBytes)
+ HashCodes.fromBytes(secretBytes).toString()
+ }
+
+ /**
+ * Safer than Class obj's getSimpleName, which may throw a "Malformed class name" error in Scala.
+ * This method mimics ScalaTest's getSimpleNameOfAnObjectsClass.
+ */
+ def getSimpleName(cls: Class[_]): String = {
+ try {
+ cls.getSimpleName
+ } catch {
+ case _: InternalError => stripDollars(stripPackages(cls.getName))
+ }
+ }
+
+ /**
+ * Remove the packages from a fully qualified class name
+ */
+ private def stripPackages(fullyQualifiedName: String): String = {
+ fullyQualifiedName.split("\\.").takeRight(1)(0)
+ }
+
+ /**
+ * Remove trailing dollar signs from a qualified class name,
+ * and return the trailing part after the last non-trailing dollar sign, if any
+ */
+ private def stripDollars(s: String): String = {
+ val lastDollarIndex = s.lastIndexOf('$')
+ if (lastDollarIndex < s.length - 1) {
+ // The last char is not a dollar sign
+ if (lastDollarIndex == -1 || !s.contains("$iw")) {
+ // The name has no dollar sign or is not an interpreter-generated
+ // class, so we should return the full string
+ s
+ } else {
+ // The class name is interpreter generated,
+ // return the part after the last dollar sign
+ // This is the same behavior as getClass.getSimpleName
+ s.substring(lastDollarIndex + 1)
+ }
+ } else {
+ // The last char is a dollar sign
+ // Find last non-dollar char
+ val lastNonDollarChar = s.reverse.find(_ != '$')
+ lastNonDollarChar match {
+ case None => s
+ case Some(c) =>
+ val lastNonDollarIndex = s.lastIndexOf(c)
+ if (lastNonDollarIndex == -1) {
+ s
+ } else {
+ // Strip the trailing dollar signs
+ // Invoke stripDollars again to get the simple name
+ stripDollars(s.substring(0, lastNonDollarIndex + 1))
+ }
+ }
+ }
+ }
}
private[util] object CallerContext extends Logging {
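For illustration, a standalone sketch of the getSimpleName fallback added above. The helpers below follow the intent of stripPackages/stripDollars but are simplified, so treat them as an approximation rather than a verified copy.

object SimpleNameSketch {
  // Class.getSimpleName can throw java.lang.InternalError("Malformed class name") for some
  // Scala-generated classes (e.g. REPL/interpreter wrappers), so fall back to parsing the
  // fully qualified name ourselves.
  def safeSimpleName(cls: Class[_]): String =
    try cls.getSimpleName
    catch { case _: InternalError => stripDollars(stripPackages(cls.getName)) }

  private def stripPackages(fqn: String): String = fqn.split("\\.").last

  // Drop trailing '$' signs; for interpreter-generated names (containing "$iw"), keep only
  // the part after the last remaining '$', roughly matching what getSimpleName would produce.
  private def stripDollars(s: String): String = {
    val trimmed = s.reverse.dropWhile(_ == '$').reverse
    if (trimmed.isEmpty) s
    else if (!trimmed.contains("$iw")) trimmed
    else trimmed.substring(trimmed.lastIndexOf('$') + 1)
  }
}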
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 375f4a692122..5c6dd45ec58e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -463,7 +463,7 @@ class ExternalAppendOnlyMap[K, V, C](
// An intermediate stream that reads from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
- private var deserializeStream = nextBatchStream()
+ private var deserializeStream: DeserializationStream = null
private var nextItem: (K, C) = null
private var objectsRead = 0
@@ -528,7 +528,11 @@ class ExternalAppendOnlyMap[K, V, C](
override def hasNext: Boolean = {
if (nextItem == null) {
if (deserializeStream == null) {
- return false
+ // In case deserializeStream has not been initialized yet
+ deserializeStream = nextBatchStream()
+ if (deserializeStream == null) {
+ return false
+ }
}
nextItem = readNextItem()
}
@@ -536,19 +540,18 @@ class ExternalAppendOnlyMap[K, V, C](
}
override def next(): (K, C) = {
- val item = if (nextItem == null) readNextItem() else nextItem
- if (item == null) {
+ if (!hasNext) {
throw new NoSuchElementException
}
+ val item = nextItem
nextItem = null
item
}
private def cleanup() {
batchIndex = batchOffsets.length // Prevent reading any other batch
- val ds = deserializeStream
- if (ds != null) {
- ds.close()
+ if (deserializeStream != null) {
+ deserializeStream.close()
deserializeStream = null
}
if (fileStream != null) {
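A minimal sketch of the lazy-initialization pattern applied above: the batch stream is only opened on the first hasNext call, and next() delegates to hasNext so an exhausted iterator consistently throws NoSuchElementException. BatchStream and LazyBatchIterator are stand-ins, not Spark's DeserializationStream or its spill reader.

import java.util.NoSuchElementException

// Stand-in for a batch-backed deserialization stream; returns null when the batch is exhausted.
trait BatchStream[T] { def readNextOrNull(): T }

class LazyBatchIterator[T >: Null <: AnyRef](openStream: () => BatchStream[T]) extends Iterator[T] {
  private var stream: BatchStream[T] = null   // opened lazily on first use
  private var nextItem: T = null

  override def hasNext: Boolean = {
    if (nextItem == null) {
      if (stream == null) {
        stream = openStream()                 // first access: open the underlying stream now
        if (stream == null) return false
      }
      nextItem = stream.readNextOrNull()
    }
    nextItem != null
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException
    val item = nextItem
    nextItem = null
    item
  }
}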
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 7367af7888bd..3ae8dfcc1cb6 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
*/
def writeFully(channel: WritableByteChannel): Unit = {
for (bytes <- getChunks()) {
- while (bytes.remaining() > 0) {
- val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
- bytes.limit(bytes.position() + ioSize)
- channel.write(bytes)
+ val curChunkLimit = bytes.limit()
+ while (bytes.hasRemaining) {
+ try {
+ val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
+ bytes.limit(bytes.position() + ioSize)
+ channel.write(bytes)
+ } finally {
+ bytes.limit(curChunkLimit)
+ }
}
}
}
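A hedged, standalone sketch of the bounded-write loop fixed above: each chunk is written in slices of at most bufferWriteChunkSize bytes, and the chunk's original limit is restored in a finally block so an exception or short write cannot leave the buffer with a shrunken limit. The default slice size below is illustrative, not Spark's configured value.

import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

object ChunkedWriteSketch {
  def writeFully(
      channel: WritableByteChannel,
      chunks: Array[ByteBuffer],
      bufferWriteChunkSize: Int = 64 * 1024): Unit = {
    for (bytes <- chunks) {
      val originalLimit = bytes.limit()
      while (bytes.hasRemaining) {
        try {
          // Cap this write at bufferWriteChunkSize bytes by temporarily shrinking the limit.
          val ioSize = math.min(bytes.remaining(), bufferWriteChunkSize)
          bytes.limit(bytes.position() + ioSize)
          channel.write(bytes)
        } finally {
          // Always restore the chunk's real limit, even if write() throws or writes partially.
          bytes.limit(originalLimit)
        }
      }
    }
  }
}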
diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
index c2261c204cd4..2225591a4ff7 100644
--- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
+++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.launcher;
+import java.time.Duration;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.HashMap;
@@ -30,6 +31,7 @@
import static org.mockito.Mockito.*;
import org.apache.spark.SparkContext;
+import org.apache.spark.SparkContext$;
import org.apache.spark.internal.config.package$;
import org.apache.spark.util.Utils;
@@ -133,6 +135,12 @@ public void testInProcessLauncher() throws Exception {
p.put(e.getKey(), e.getValue());
}
System.setProperties(p);
+ // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet.
+ // Wait for a reasonable amount of time to avoid creating two active SparkContexts in the JVM.
+ // See SPARK-23019 and SparkContext.stop() for details.
+ eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> {
+ assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty());
+ });
}
}
@@ -141,26 +149,47 @@ private void inProcessLauncherTestImpl() throws Exception {
SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class);
doAnswer(invocation -> {
SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0];
- transitions.add(h.getState());
+ synchronized (transitions) {
+ transitions.add(h.getState());
+ }
return null;
}).when(listener).stateChanged(any(SparkAppHandle.class));
- SparkAppHandle handle = new InProcessLauncher()
- .setMaster("local")
- .setAppResource(SparkLauncher.NO_RESOURCE)
- .setMainClass(InProcessTestApp.class.getName())
- .addAppArgs("hello")
- .startApplication(listener);
-
- waitFor(handle);
- assertEquals(SparkAppHandle.State.FINISHED, handle.getState());
-
- // Matches the behavior of LocalSchedulerBackend.
- List expected = Arrays.asList(
- SparkAppHandle.State.CONNECTED,
- SparkAppHandle.State.RUNNING,
- SparkAppHandle.State.FINISHED);
- assertEquals(expected, transitions);
+ SparkAppHandle handle = null;
+ try {
+ synchronized (InProcessTestApp.LOCK) {
+ handle = new InProcessLauncher()
+ .setMaster("local")
+ .setAppResource(SparkLauncher.NO_RESOURCE)
+ .setMainClass(InProcessTestApp.class.getName())
+ .addAppArgs("hello")
+ .startApplication(listener);
+
+ // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here
+ // we wait until we know that the connection between the app and the launcher has been
+ // established before allowing the app to finish.
+ final SparkAppHandle _handle = handle;
+ eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> {
+ assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState());
+ });
+
+ InProcessTestApp.LOCK.wait(5000);
+ }
+
+ waitFor(handle);
+ assertEquals(SparkAppHandle.State.FINISHED, handle.getState());
+
+ // Matches the behavior of LocalSchedulerBackend.
+ List expected = Arrays.asList(
+ SparkAppHandle.State.CONNECTED,
+ SparkAppHandle.State.RUNNING,
+ SparkAppHandle.State.FINISHED);
+ assertEquals(expected, transitions);
+ } finally {
+ if (handle != null) {
+ handle.kill();
+ }
+ }
}
public static class SparkLauncherTestApp {
@@ -176,10 +205,26 @@ public static void main(String[] args) throws Exception {
public static class InProcessTestApp {
+ /**
+ * SPARK-23020: there's a race caused by a child app finishing too quickly. This would cause
+ * the InProcessAppHandle to dispose of itself even before the child connection was properly
+ * established, so no state changes would be detected for the application and its final
+ * state would be LOST.
+ *
+ * It's not really possible to fix that race safely in the handle code itself without changing
+ * the way in-process apps talk to the launcher library, so we work around that in the test by
+ * synchronizing on this object.
+ */
+ public static final Object LOCK = new Object();
+
public static void main(String[] args) throws Exception {
assertNotEquals(0, args.length);
assertEquals(args[0], "hello");
new SparkContext().stop();
+
+ synchronized (LOCK) {
+ LOCK.notifyAll();
+ }
}
}
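The test above leans on an eventually-style polling helper. As a reference for that pattern only (the launcher suite has its own Java helper; this is not its implementation, and the names and usage below are illustrative), a minimal Scala sketch:

import scala.concurrent.duration._

object PollingAssert {
  // Re-run `check` until it stops throwing or `timeout` elapses, sleeping `interval` between
  // tries; once the deadline passes, the last failure is allowed to propagate.
  def eventually[T](timeout: FiniteDuration, interval: FiniteDuration)(check: => T): T = {
    val deadline = System.nanoTime() + timeout.toNanos
    while (true) {
      try {
        return check
      } catch {
        case _: Throwable if System.nanoTime() < deadline => Thread.sleep(interval.toMillis)
      }
    }
    throw new IllegalStateException("unreachable")
  }
}

// Usage, mirroring the shape of the checks in the test above (currentState is illustrative):
// PollingAssert.eventually(5.seconds, 10.millis) { assert(currentState != "UNKNOWN") }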
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 46b0516e3614..a0664b30d6cc 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -21,6 +21,7 @@
import org.junit.Test;
import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
public class TaskMemoryManagerSuite {
@@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() {
Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
}
+ @Test
+ public void freeingPageSetsPageNumberToSpecialConstant() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = manager.allocatePage(256, c);
+ c.freePage(dataPage);
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void freeingPageDirectlyInAllocatorTriggersAssertionError() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = manager.allocatePage(256, c);
+ MemoryAllocator.HEAP.free(dataPage);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256);
+ manager.freePage(dataPage, c);
+ }
+
@Test
public void cooperativeSpilling() {
final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
index db91329c94cb..1b7739cce6fb 100644
--- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -19,6 +19,8 @@
import java.io.IOException;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
public class TestMemoryConsumer extends MemoryConsumer {
public TestMemoryConsumer(TaskMemoryManager memoryManager, MemoryMode mode) {
super(memoryManager, 1024L, mode);
@@ -43,6 +45,11 @@ void free(long size) {
used -= size;
taskMemoryManager.releaseExecutionMemory(size, this);
}
+
+ public void freePage(MemoryBlock page) {
+ used -= page.size();
+ taskMemoryManager.freePage(page, this);
+ }
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index af4975c888d6..411cd5cb5733 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite {
public int compare(
Object leftBaseObject,
long leftBaseOffset,
+ int leftBaseLength,
Object rightBaseObject,
- long rightBaseOffset) {
+ long rightBaseOffset,
+ int rightBaseLength) {
return 0;
}
};
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 594f07dd780f..85ffdca436e1 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
public int compare(
Object leftBaseObject,
long leftBaseOffset,
+ int leftBaseLength,
Object rightBaseObject,
- long rightBaseOffset) {
+ long rightBaseOffset,
+ int rightBaseLength) {
return 0;
}
};
@@ -127,7 +129,6 @@ public int compare(
final UnsafeSorterIterator iter = sorter.getSortedIterator();
int iterLength = 0;
long prevPrefix = -1;
- Arrays.sort(dataToSort);
while (iter.hasNext()) {
iter.loadNext();
final String str =
@@ -164,8 +165,10 @@ public void freeAfterOOM() {
public int compare(
Object leftBaseObject,
long leftBaseOffset,
+ int leftBaseLength,
Object rightBaseObject,
- long rightBaseOffset) {
+ long rightBaseOffset,
+ int rightBaseLength) {
return 0;
}
};
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
index 94f5805853e1..f8e233a05a44 100644
--- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
+++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
@@ -38,6 +38,7 @@ public static void test() {
tc.attemptNumber();
tc.partitionId();
tc.stageId();
+ tc.stageAttemptNumber();
tc.taskAttemptId();
}
@@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
context.stageId();
+ context.stageAttemptNumber();
context.partitionId();
context.addTaskCompletionListener(this);
}
diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json
index f8e27703c0de..5c42ac1d87f4 100644
--- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json
@@ -7,6 +7,9 @@
"resultSize" : [ 2010.0, 2065.0, 2065.0 ],
"jvmGcTime" : [ 0.0, 0.0, 7.0 ],
"resultSerializationTime" : [ 0.0, 0.0, 2.0 ],
+ "gettingResultTime" : [ 0.0, 0.0, 0.0 ],
+ "schedulerDelay" : [ 2.0, 6.0, 53.0 ],
+ "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ],
"memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ],
"diskBytesSpilled" : [ 0.0, 0.0, 0.0 ],
"inputMetrics" : {
diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json
index a28bda16a956..e6b705989cc9 100644
--- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json
@@ -7,6 +7,9 @@
"resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ],
"jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ],
+ "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"inputMetrics" : {
diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json
index ede3eaed1d1d..788f28cf7b36 100644
--- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json
@@ -7,6 +7,9 @@
"resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ],
"jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ],
"resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ],
+ "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ],
+ "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ],
"inputMetrics" : {
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 3990ee1ec326..5d0ffd92647b 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
System.gc()
assert(ref.get.isEmpty)
- // Getting a garbage collected accum should throw error
- intercept[IllegalStateException] {
- AccumulatorContext.get(accId)
- }
+ // Getting a garbage collected accum should return None.
+ assert(AccumulatorContext.get(accId).isEmpty)
// Getting a normal accumulator. Note: this has to be separate because referencing an
// accumulator above in an `assert` would keep it from being garbage collected.
diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
index 91355f736290..a5bdc9579072 100644
--- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
+++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
@@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem {
override def markSupported(): Boolean = wrapped.markSupported()
override def close(): Unit = {
- wrapped.close()
- removeOpenStream(wrapped)
+ try {
+ wrapped.close()
+ } finally {
+ removeOpenStream(wrapped)
+ }
}
override def read(): Int = wrapped.read()
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index a0cae5a9e011..784beace9000 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
import scala.collection.mutable
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{mock, never, verify, when}
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
import org.apache.spark.executor.TaskMetrics
@@ -26,6 +28,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.scheduler.ExternalClusterManager
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.local.LocalSchedulerBackend
+import org.apache.spark.storage.BlockManagerMaster
import org.apache.spark.util.ManualClock
/**
@@ -898,12 +901,7 @@ class ExecutorAllocationManagerSuite
assert(maxNumExecutorsNeeded(manager) === 0)
schedule(manager)
- // Verify executor is timeout but numExecutorsTarget is not recalculated
- assert(numExecutorsTarget(manager) === 3)
-
- // Schedule again to recalculate the numExecutorsTarget after executor is timeout
- schedule(manager)
- // Verify that current number of executors should be ramp down when executor is timeout
+ // Verify executor has timed out and numExecutorsTarget is recalculated
assert(numExecutorsTarget(manager) === 2)
}
@@ -1050,6 +1048,85 @@ class ExecutorAllocationManagerSuite
assert(removeTimes(manager) === Map.empty)
}
+ test("SPARK-23365 Don't update target num executors when killing idle executors") {
+ val minExecutors = 1
+ val initialExecutors = 1
+ val maxExecutors = 2
+ val conf = new SparkConf()
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.shuffle.service.enabled", "true")
+ .set("spark.dynamicAllocation.minExecutors", minExecutors.toString)
+ .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString)
+ .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString)
+ .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms")
+ .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms")
+ .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms")
+ val mockAllocationClient = mock(classOf[ExecutorAllocationClient])
+ val mockBMM = mock(classOf[BlockManagerMaster])
+ val manager = new ExecutorAllocationManager(
+ mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM)
+ val clock = new ManualClock()
+ manager.setClock(clock)
+
+ when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true)
+ // test setup -- job with 2 tasks, scale up to two executors
+ assert(numExecutorsTarget(manager) === 1)
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty)))
+ manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2)))
+ clock.advance(1000)
+ manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis())
+ assert(numExecutorsTarget(manager) === 2)
+ val taskInfo0 = createTaskInfo(0, 0, "executor-1")
+ manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0))
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty)))
+ val taskInfo1 = createTaskInfo(1, 1, "executor-2")
+ manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1))
+ assert(numExecutorsTarget(manager) === 2)
+
+ // have one task finish -- we should adjust the target number of executors down
+ // but we should *not* kill any executors yet
+ manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null))
+ assert(maxNumExecutorsNeeded(manager) === 1)
+ assert(numExecutorsTarget(manager) === 2)
+ clock.advance(1000)
+ manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis())
+ assert(numExecutorsTarget(manager) === 1)
+ verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any())
+
+ // now we cross the idle timeout for executor-1, so we kill it. The really important
+ // thing here is that we do *not* ask the executor allocation client to adjust the target
+ // number of executors down
+ when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false))
+ .thenReturn(Seq("executor-1"))
+ clock.advance(3000)
+ schedule(manager)
+ assert(maxNumExecutorsNeeded(manager) === 1)
+ assert(numExecutorsTarget(manager) === 1)
+ // The important verification: we killed the executor but did not adjust the target count
+ verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false)
+ }
+
+ test("SPARK-26758 check executor target number after idle time out ") {
+ sc = createSparkContext(1, 5, 3)
+ val manager = sc.executorAllocationManager.get
+ val clock = new ManualClock(10000L)
+ manager.setClock(clock)
+ assert(numExecutorsTarget(manager) === 3)
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty)))
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 2, Map.empty)))
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-3", new ExecutorInfo("host1", 3, Map.empty)))
+ // make all the executors idle, so that they will be killed
+ clock.advance(executorIdleTimeout * 1000)
+ schedule(manager)
+ // once schedule() is run, the target executor number should be 1
+ assert(numExecutorsTarget(manager) === 1)
+ }
+
private def createSparkContext(
minExecutors: Int = 1,
maxExecutors: Int = 5,
@@ -1268,7 +1345,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend
override def killExecutors(
executorIds: Seq[String],
- replace: Boolean,
+ adjustTargetNumExecutors: Boolean,
+ countFailures: Boolean,
force: Boolean): Seq[String] = executorIds
override def start(): Unit = sb.start()
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index e9539dc73f6f..55a9122cf902 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -244,7 +244,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
for (i <- 0 until testOutputCopies) {
// Shift values by i so that they're different in the output
val alteredOutput = testOutput.map(b => (b + i).toByte)
- channel.write(ByteBuffer.wrap(alteredOutput))
+ val buffer = ByteBuffer.wrap(alteredOutput)
+ while (buffer.hasRemaining) {
+ channel.write(buffer)
+ }
}
channel.close()
file.close()
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 8a77aea75a99..61da4138896c 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils
/**
@@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
override def afterEach() {
try {
resetSparkContext()
+ JobCancellationSuite.taskStartedSemaphore.drainPermits()
+ JobCancellationSuite.taskCancelledSemaphore.drainPermits()
+ JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits()
+ JobCancellationSuite.executionOfInterruptibleCounter.set(0)
} finally {
super.afterEach()
}
@@ -320,6 +325,67 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
f2.get()
}
+ test("interruptible iterator of shuffle reader") {
+ // In this test case, we create a Spark job of two stages. The second stage is cancelled during
+ // execution and a counter is used to make sure that the corresponding tasks are indeed
+ // cancelled.
+ import JobCancellationSuite._
+ sc = new SparkContext("local[2]", "test interruptible iterator")
+
+ // Increase the number of elements to be processed to avoid this test being flaky.
+ val numElements = 10000
+ val taskCompletedSem = new Semaphore(0)
+
+ sc.addSparkListener(new SparkListener {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ // release taskCancelledSemaphore once the cancelled reduce stage has completed
+ if (stageCompleted.stageInfo.stageId == 1) {
+ taskCancelledSemaphore.release(numElements)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ if (taskEnd.stageId == 1) { // make sure tasks are completed
+ taskCompletedSem.release()
+ }
+ }
+ })
+
+ // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be
+ // interrupted by `InterruptibleIterator`.
+ sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+
+ val f = sc.parallelize(1 to numElements).map { i => (i, i) }
+ .repartitionAndSortWithinPartitions(new HashPartitioner(1))
+ .mapPartitions { iter =>
+ taskStartedSemaphore.release()
+ iter
+ }.foreachAsync { x =>
+ // Block this code from being executed until the job gets cancelled. In this case, if the
+ // source iterator is interruptible, the number of increments should stay under
+ // `numElements`.
+ taskCancelledSemaphore.acquire()
+ executionOfInterruptibleCounter.getAndIncrement()
+ }
+
+ taskStartedSemaphore.acquire()
+ // The job is cancelled when:
+ // 1. the task in the reduce stage has started, guaranteed by the previous line.
+ // 2. the task in the reduce stage is blocked, as taskCancelledSemaphore is not released until
+ // the JobCancelled event is posted.
+ // After the job is cancelled, the task in the reduce stage is cancelled asynchronously, so
+ // only part of the input should get processed (it is very unlikely that Spark can process
+ // 10000 elements between JobCancelled being posted and the task really being killed).
+ f.cancel()
+
+ val e = intercept[SparkException](f.get()).getCause
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+
+ // Make sure tasks are indeed completed.
+ taskCompletedSem.acquire()
+ assert(executionOfInterruptibleCounter.get() < numElements)
+ }
+
def testCount() {
// Cancel before launching any tasks
{
@@ -381,7 +447,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
object JobCancellationSuite {
+ // To avoid any headaches, reset these global variables in the companion class's afterEach block
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
val twoJobsSharingStageSemaphore = new Semaphore(0)
+ val executionOfInterruptibleCounter = new AtomicInteger(0)
}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 155ca17db726..9206b5debf4f 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -262,14 +262,11 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva
test("defaultPartitioner") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150)
- val rdd2 = sc
- .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+ val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
.partitionBy(new HashPartitioner(10))
- val rdd3 = sc
- .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
+ val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
.partitionBy(new HashPartitioner(100))
- val rdd4 = sc
- .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+ val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
.partitionBy(new HashPartitioner(9))
val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11)
@@ -284,7 +281,42 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva
assert(partitioner3.numPartitions == rdd3.getNumPartitions)
assert(partitioner4.numPartitions == rdd3.getNumPartitions)
assert(partitioner5.numPartitions == rdd4.getNumPartitions)
+ }
+ test("defaultPartitioner when defaultParallelism is set") {
+ assert(!sc.conf.contains("spark.default.parallelism"))
+ try {
+ sc.conf.set("spark.default.parallelism", "4")
+
+ val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150)
+ val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+ .partitionBy(new HashPartitioner(10))
+ val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
+ .partitionBy(new HashPartitioner(100))
+ val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+ .partitionBy(new HashPartitioner(9))
+ val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11)
+ val rdd6 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+ .partitionBy(new HashPartitioner(3))
+
+ val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2)
+ val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3)
+ val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1)
+ val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3)
+ val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5)
+ val partitioner6 = Partitioner.defaultPartitioner(rdd5, rdd5)
+ val partitioner7 = Partitioner.defaultPartitioner(rdd1, rdd6)
+
+ assert(partitioner1.numPartitions == rdd2.getNumPartitions)
+ assert(partitioner2.numPartitions == rdd3.getNumPartitions)
+ assert(partitioner3.numPartitions == rdd3.getNumPartitions)
+ assert(partitioner4.numPartitions == rdd3.getNumPartitions)
+ assert(partitioner5.numPartitions == rdd4.getNumPartitions)
+ assert(partitioner6.numPartitions == sc.defaultParallelism)
+ assert(partitioner7.numPartitions == sc.defaultParallelism)
+ } finally {
+ sc.conf.remove("spark.default.parallelism")
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 3931d53b4ae0..ced5a06516f7 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+ new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
val data1 = (1 to 10).map { x => x -> x}
// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+ new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
val data2 = (11 to 20).map { x => x -> x}
// interleave writes of both attempts -- we want to test that both attempts can occur
@@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
- new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+ new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
index 8feb3dee050d..051a13c9413e 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -60,6 +60,7 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(sc.getRDDStorageInfo.size === 0)
rdd.collect()
+ sc.listenerBus.waitUntilEmpty(10000)
assert(sc.getRDDStorageInfo.size === 1)
assert(sc.getRDDStorageInfo.head.isCached)
assert(sc.getRDDStorageInfo.head.memSize > 0)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index b30bd74812b3..ce9f2be1c02d 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import java.io.File
import java.net.{MalformedURLException, URI}
import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
import scala.concurrent.duration._
@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
test("Cancelling stages/jobs with custom reasons.") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+ sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
val REASON = "You shall not pass"
- val slices = 10
- val listener = new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
- if (SparkContextSuite.cancelStage) {
- eventually(timeout(10.seconds)) {
- assert(SparkContextSuite.isTaskStarted)
+ for (cancelWhat <- Seq("stage", "job")) {
+ // This countdown latch is used to make sure the stage or job is cancelled in the listener
+ val latch = new CountDownLatch(1)
+
+ val listener = cancelWhat match {
+ case "stage" =>
+ new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+ sc.cancelStage(taskStart.stageId, REASON)
+ latch.countDown()
+ }
}
- sc.cancelStage(taskStart.stageId, REASON)
- SparkContextSuite.cancelStage = false
- SparkContextSuite.semaphore.release(slices)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
- if (SparkContextSuite.cancelJob) {
- eventually(timeout(10.seconds)) {
- assert(SparkContextSuite.isTaskStarted)
+ case "job" =>
+ new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ sc.cancelJob(jobStart.jobId, REASON)
+ latch.countDown()
+ }
}
- sc.cancelJob(jobStart.jobId, REASON)
- SparkContextSuite.cancelJob = false
- SparkContextSuite.semaphore.release(slices)
- }
}
- }
- sc.addSparkListener(listener)
-
- for (cancelWhat <- Seq("stage", "job")) {
- SparkContextSuite.semaphore.drainPermits()
- SparkContextSuite.isTaskStarted = false
- SparkContextSuite.cancelStage = (cancelWhat == "stage")
- SparkContextSuite.cancelJob = (cancelWhat == "job")
+ sc.addSparkListener(listener)
val ex = intercept[SparkException] {
- sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
- SparkContextSuite.isTaskStarted = true
- // Block waiting for the listener to cancel the stage or job.
- SparkContextSuite.semaphore.acquire()
+ sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+ x.synchronized {
+ x.wait()
+ }
x
}.count()
}
@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
}
+ latch.await(20, TimeUnit.SECONDS)
eventually(timeout(20.seconds)) {
assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
}
+ sc.removeSparkListener(listener)
}
}
@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
object SparkContextSuite {
- @volatile var cancelJob = false
- @volatile var cancelStage = false
@volatile var isTaskStarted = false
@volatile var taskKilled = false
@volatile var taskSucceeded = false
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67412f2..6f9b583898c3 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.api.python
import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
import java.nio.charset.StandardCharsets
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
class PythonRDDSuite extends SparkFunSuite {
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
("a".getBytes(StandardCharsets.UTF_8), null),
(null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
}
+
+ test("python server error handling") {
+ val authHelper = new SocketAuthHelper(new SparkConf())
+ val errorServer = new ExceptionPythonServer(authHelper)
+ val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+ authHelper.authToServer(client)
+ val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+ assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+ }
+
+ class ExceptionPythonServer(authHelper: SocketAuthHelper)
+ extends PythonServer[Unit](authHelper, "error-server") {
+
+ override def handleConnection(sock: Socket): Unit = {
+ throw new Exception("exception within handleConnection")
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 159629825c67..9ad2e9a5e74a 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
assert(broadcast.value.sum === 10)
}
+ test("One broadcast value instance per executor") {
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName("test")
+
+ sc = new SparkContext(conf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val instances = sc.parallelize(1 to 10)
+ .map(x => System.identityHashCode(broadcast.value))
+ .collect()
+ .toSet
+
+ assert(instances.size === 1)
+ }
+
+ test("One broadcast value instance per executor when memory is constrained") {
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName("test")
+ .set("spark.memory.useLegacyMode", "true")
+ .set("spark.storage.memoryFraction", "0.0")
+
+ sc = new SparkContext(conf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val instances = sc.parallelize(1 to 10)
+ .map(x => System.identityHashCode(broadcast.value))
+ .collect()
+ .toSet
+
+ assert(instances.size === 1)
+ }
+
/**
* Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
*
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala
deleted file mode 100644
index ab24a76e20a3..000000000000
--- a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy
-
-import java.security.PrivilegedExceptionAction
-
-import scala.util.Random
-
-import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.fs.permission.{FsAction, FsPermission}
-import org.apache.hadoop.security.UserGroupInformation
-import org.scalatest.Matchers
-
-import org.apache.spark.SparkFunSuite
-
-class SparkHadoopUtilSuite extends SparkFunSuite with Matchers {
- test("check file permission") {
- import FsAction._
- val testUser = s"user-${Random.nextInt(100)}"
- val testGroups = Array(s"group-${Random.nextInt(100)}")
- val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups)
-
- testUgi.doAs(new PrivilegedExceptionAction[Void] {
- override def run(): Void = {
- val sparkHadoopUtil = new SparkHadoopUtil
-
- // If file is owned by user and user has access permission
- var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
-
- // If file is owned by user but user has no access permission
- status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
-
- val otherUser = s"test-${Random.nextInt(100)}"
- val otherGroup = s"test-${Random.nextInt(100)}"
-
- // If file is owned by user's group and user's group has access permission
- status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
-
- // If file is owned by user's group but user's group has no access permission
- status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
-
- // If file is owned by other user and this user has access permission
- status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
-
- // If file is owned by other user but this user has no access permission
- status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE)
- sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
- sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
-
- null
- }
- })
- }
-
- private def fileStatus(
- owner: String,
- group: String,
- userAction: FsAction,
- groupAction: FsAction,
- otherAction: FsAction): FileStatus = {
- new FileStatus(0L,
- false,
- 0,
- 0L,
- 0L,
- 0L,
- new FsPermission(userAction, groupAction, otherAction),
- owner,
- group,
- null)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 27dd43533234..2fbad1f6a13e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy
import java.io._
import java.net.URI
import java.nio.charset.StandardCharsets
-import java.nio.file.Files
+import java.nio.file.{Files, Paths}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -105,6 +105,9 @@ class SparkSubmitSuite
// Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
implicit val defaultSignaler: Signaler = ThreadSignaler
+ private val emptyIvySettings = File.createTempFile("ivy", ".xml")
+ FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8)
+
override def beforeEach() {
super.beforeEach()
System.setProperty("spark.testing", "true")
@@ -520,6 +523,7 @@ class SparkSubmitSuite
"--repositories", repo,
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -530,7 +534,6 @@ class SparkSubmitSuite
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val dep = MavenCoordinate("my.great.dep", "mylib", "0.1")
- // Test using "spark.jars.packages" and "spark.jars.repositories" configurations.
IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo =>
val args = Seq(
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
@@ -540,6 +543,7 @@ class SparkSubmitSuite
"--conf", s"spark.jars.repositories=$repo",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -550,7 +554,6 @@ class SparkSubmitSuite
// See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log
ignore("correctly builds R packages included in a jar with --packages") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -563,6 +566,7 @@ class SparkSubmitSuite
"--master", "local-cluster[2,1,1024]",
"--packages", main.toString,
"--repositories", repo,
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
"--verbose",
"--conf", "spark.ui.enabled=false",
rScriptDir)
@@ -573,7 +577,6 @@ class SparkSubmitSuite
test("include an external JAR in SparkR") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val rScriptDir =
Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator)
@@ -606,10 +609,13 @@ class SparkSubmitSuite
}
test("resolves command line argument paths correctly") {
- val jars = "/jar1,/jar2" // --jars
- val files = "local:/file1,file2" // --files
- val archives = "file:/archive1,archive2" // --archives
- val pyFiles = "py-file1,py-file2" // --py-files
+ val dir = Utils.createTempDir()
+ val archive = Paths.get(dir.toPath.toString, "single.zip")
+ Files.createFile(archive)
+ val jars = "/jar1,/jar2"
+ val files = "local:/file1,file2"
+ val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3"
+ val pyFiles = "py-file1,py-file2"
// Test jars and files
val clArgs = Seq(
@@ -636,9 +642,10 @@ class SparkSubmitSuite
val appArgs2 = new SparkSubmitArguments(clArgs2)
val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2)
appArgs2.files should be (Utils.resolveURIs(files))
- appArgs2.archives should be (Utils.resolveURIs(archives))
+ appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3")
conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files))
- conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives))
+ conf2.get("spark.yarn.dist.archives") should fullyMatch regex
+ ("file:/archive1,file:.*#archive3")
// Test python files
val clArgs3 = Seq(
@@ -657,6 +664,29 @@ class SparkSubmitSuite
conf3.get(PYSPARK_PYTHON.key) should be ("python3.5")
}
+ test("ambiguous archive mapping results in error message") {
+ val dir = Utils.createTempDir()
+ val archive1 = Paths.get(dir.toPath.toString, "first.zip")
+ val archive2 = Paths.get(dir.toPath.toString, "second.zip")
+ Files.createFile(archive1)
+ Files.createFile(archive2)
+ val jars = "/jar1,/jar2"
+ val files = "local:/file1,file2"
+ val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3"
+ val pyFiles = "py-file1,py-file2"
+
+ // Test files and archives (Yarn)
+ val clArgs2 = Seq(
+ "--master", "yarn",
+ "--class", "org.SomeClass",
+ "--files", files,
+ "--archives", archives,
+ "thejar.jar"
+ )
+
+ testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files")
+ }
+
test("resolves config paths correctly") {
val jars = "/jar1,/jar2" // spark.jars
val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
@@ -906,6 +936,25 @@ class SparkSubmitSuite
}
}
+ test("remove copies of application jar from classpath") {
+ val fs = File.separator
+ val sparkConf = new SparkConf(false)
+ val hadoopConf = new Configuration()
+ val secMgr = new SecurityManager(sparkConf)
+
+ val appJarName = "myApp.jar"
+ val jar1Name = "myJar1.jar"
+ val jar2Name = "myJar2.jar"
+ val userJar = s"file:/path${fs}to${fs}app${fs}jar$fs$appJarName"
+ val jars = s"file:/$jar1Name,file:/$appJarName,file:/$jar2Name"
+
+ val resolvedJars = DependencyUtils
+ .resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf, secMgr)
+
+ assert(!resolvedJars.contains(appJarName))
+ assert(resolvedJars.contains(jar1Name) && resolvedJars.contains(jar2Name))
+ }
+
test("Avoid re-upload remote resources in yarn client mode") {
val hadoopConf = new Configuration()
updateConfWithFakeS3Fs(hadoopConf)
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index eb8c203ae775..a0f09891787e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(jarPath.indexOf("mydep") >= 0, "should find dependency")
}
}
+
+ test("SPARK-10878: test resolution files cleaned after resolving artifact") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1")
+
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath))
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ ivySettings,
+ isTest = true)
+ val r = """.*org.apache.spark-spark-submit-parent-.*""".r
+ assert(!ivySettings.getDefaultCache.listFiles.map(_.getName)
+ .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned")
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index bf7480d79f8a..155564a65c60 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite
syncExecutors(sc)
sc.schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.killExecutors(Seq(executorId), replace = false, force)
+ b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false,
+ force)
case _ => fail("expected coarse grained scheduler")
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 84ee01c7f5aa..bf2b04484a42 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -29,9 +29,11 @@ import scala.language.postfixOps
import com.google.common.io.{ByteStreams, Files}
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hdfs.DistributedFileSystem
+import org.apache.hadoop.security.AccessControlException
import org.json4s.jackson.JsonMethods._
-import org.mockito.Matchers.any
-import org.mockito.Mockito.{mock, spy, verify}
+import org.mockito.ArgumentMatcher
+import org.mockito.Matchers.{any, argThat}
+import org.mockito.Mockito.{doReturn, doThrow, mock, spy, verify, when}
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._
@@ -149,8 +151,10 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) {
var mergeApplicationListingCall = 0
- override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = {
- super.mergeApplicationListing(fileStatus)
+ override protected def mergeApplicationListing(
+ fileStatus: FileStatus,
+ lastSeen: Long): Unit = {
+ super.mergeApplicationListing(fileStatus, lastSeen)
mergeApplicationListingCall += 1
}
}
@@ -663,6 +667,151 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
freshUI.get.ui.store.job(0)
}
+ test("clean up stale app information") {
+ val storeDir = Utils.createTempDir()
+ val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath())
+ val provider = spy(new FsHistoryProvider(conf))
+ val appId = "new1"
+
+ // Write logs for two app attempts.
+ doReturn(1L).when(provider).getNewLastScanTime()
+ val attempt1 = newLogFile(appId, Some("1"), inProgress = false)
+ writeFile(attempt1, true, None,
+ SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")),
+ SparkListenerJobStart(0, 1L, Nil, null),
+ SparkListenerApplicationEnd(5L)
+ )
+ val attempt2 = newLogFile(appId, Some("2"), inProgress = false)
+ writeFile(attempt2, true, None,
+ SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")),
+ SparkListenerJobStart(0, 1L, Nil, null),
+ SparkListenerApplicationEnd(5L)
+ )
+ updateAndCheck(provider) { list =>
+ assert(list.size === 1)
+ assert(list(0).id === appId)
+ assert(list(0).attempts.size === 2)
+ }
+
+ // Load the app's UI.
+ val ui = provider.getAppUI(appId, Some("1"))
+ assert(ui.isDefined)
+
+ // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since
+ // attempt 2 still exists, listing data should be there.
+ doReturn(2L).when(provider).getNewLastScanTime()
+ attempt1.delete()
+ updateAndCheck(provider) { list =>
+ assert(list.size === 1)
+ assert(list(0).id === appId)
+ assert(list(0).attempts.size === 1)
+ }
+ assert(!ui.get.valid)
+ assert(provider.getAppUI(appId, None) === None)
+
+ // Delete the second attempt's log file. Now everything should go away.
+ doReturn(3L).when(provider).getNewLastScanTime()
+ attempt2.delete()
+ updateAndCheck(provider) { list =>
+ assert(list.isEmpty)
+ }
+ }
+
+ test("SPARK-21571: clean up removes invalid history files") {
+ // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that
+ // until we figure out what's causing the problem.
+ val clock = new ManualClock(TimeUnit.DAYS.toMillis(120))
+ val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d")
+ val provider = new FsHistoryProvider(conf, clock) {
+ override def getNewLastScanTime(): Long = clock.getTimeMillis()
+ }
+
+ // Create 0-byte in-progress and complete files
+ var logCount = 0
+ var validLogCount = 0
+
+ val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true)
+ emptyInProgress.createNewFile()
+ emptyInProgress.setLastModified(clock.getTimeMillis())
+ logCount += 1
+
+ val slowApp = newLogFile("slowApp", None, inProgress = true)
+ slowApp.createNewFile()
+ slowApp.setLastModified(clock.getTimeMillis())
+ logCount += 1
+
+ val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false)
+ emptyFinished.createNewFile()
+ emptyFinished.setLastModified(clock.getTimeMillis())
+ logCount += 1
+
+ // Create an incomplete log file that has an end record but no start record.
+ val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false)
+ writeFile(corrupt, true, None, SparkListenerApplicationEnd(0))
+ corrupt.setLastModified(clock.getTimeMillis())
+ logCount += 1
+
+ provider.checkForLogs()
+ provider.cleanLogs()
+ assert(new File(testDir.toURI).listFiles().size === logCount)
+
+ // Move the clock forward 1 day and scan the files again. They should still be there.
+ clock.advance(TimeUnit.DAYS.toMillis(1))
+ provider.checkForLogs()
+ provider.cleanLogs()
+ assert(new File(testDir.toURI).listFiles().size === logCount)
+
+ // Update the slow app to contain valid info. Code should detect the change and not clean
+ // it up.
+ writeFile(slowApp, true, None,
+ SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None))
+ slowApp.setLastModified(clock.getTimeMillis())
+ validLogCount += 1
+
+ // Move the clock forward another 2 days and scan the files again. This time the cleaner should
+ // pick up the invalid files and get rid of them.
+ clock.advance(TimeUnit.DAYS.toMillis(2))
+ provider.checkForLogs()
+ provider.cleanLogs()
+ assert(new File(testDir.toURI).listFiles().size === validLogCount)
+ }
+
+ test("SPARK-24948: blacklist files we don't have read permission on") {
+ val clock = new ManualClock(1533132471)
+ val provider = new FsHistoryProvider(createTestConf(), clock)
+ val accessDenied = newLogFile("accessDenied", None, inProgress = false)
+ writeFile(accessDenied, true, None,
+ SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None))
+ val accessGranted = newLogFile("accessGranted", None, inProgress = false)
+ writeFile(accessGranted, true, None,
+ SparkListenerApplicationStart("accessGranted", Some("accessGranted"), 1L, "test", None),
+ SparkListenerApplicationEnd(5L))
+ val mockedFs = spy(provider.fs)
+ doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open(
+ argThat(new ArgumentMatcher[Path]() {
+ override def matches(path: Any): Boolean = {
+ path.asInstanceOf[Path].getName.toLowerCase == "accessdenied"
+ }
+ }))
+ val mockedProvider = spy(provider)
+ when(mockedProvider.fs).thenReturn(mockedFs)
+ updateAndCheck(mockedProvider) { list =>
+ list.size should be(1)
+ }
+ writeFile(accessDenied, true, None,
+ SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None),
+ SparkListenerApplicationEnd(5L))
+ // Check a second time in order to exercise the blacklist filter too
+ updateAndCheck(mockedProvider) { list =>
+ list.size should be(1)
+ }
+ val accessDeniedPath = new Path(accessDenied.getPath)
+ assert(mockedProvider.isBlacklisted(accessDeniedPath))
+ clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d
+ mockedProvider.cleanLogs()
+ assert(!mockedProvider.isBlacklisted(accessDeniedPath))
+ }
+
/**
* Asks the provider to check for logs and calls a function to perform checks on the updated
* app list. Example:
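As a side note on the clock arithmetic driving the SPARK-21571 cleanup test above, a minimal sketch (plain values only, not the provider's code) of the age cutoff it exercises:

import java.util.concurrent.TimeUnit

object LogAgeSketch {
  def main(args: Array[String]): Unit = {
    val maxAgeMs = TimeUnit.DAYS.toMillis(2)   // MAX_LOG_AGE_S = 2d in the test
    var now = TimeUnit.DAYS.toMillis(120)      // the ManualClock's starting point
    val lastModified = now                     // log files written "now"
    now += TimeUnit.DAYS.toMillis(1)
    assert(now - lastModified <= maxAgeMs)     // after 1 day: still retained
    now += TimeUnit.DAYS.toMillis(2)
    assert(now - lastModified > maxAgeMs)      // after 3 days total: eligible for cleanup
  }
}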
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 3738f85da583..4c0619322536 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._
import org.apache.spark.status.api.v1.ApplicationInfo
import org.apache.spark.status.api.v1.JobData
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ResetSystemProperties, Utils}
+import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils}
/**
* A collection of tests against the historyserver, including comparing responses from the json
@@ -294,6 +294,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
all (siteRelativeLinks) should startWith (uiRoot)
}
+ test("/version api endpoint") {
+ val response = getUrl("version")
+ assert(response.contains(SPARK_VERSION))
+ }
+
test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
val uiRoot = "/testwebproxybase"
System.setProperty("spark.ui.proxyBase", uiRoot)
@@ -564,7 +569,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
assert(jobcount === getNumJobs("/jobs"))
// no need to retain the test dir now the tests complete
- logDir.deleteOnExit()
+ ShutdownHookManager.registerShutdownDeleteDir(logDir)
}
test("ui and api authorization checks") {
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index e505bc018857..0f56c6902093 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -376,6 +376,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
assert(filteredVariables == Map("SPARK_VAR" -> "1"))
}
+ test("client does not send 'SPARK_HOME' env var by default") {
+ val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_HOME" -> "1")
+ val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables)
+ assert(filteredVariables == Map("SPARK_VAR" -> "1"))
+ }
+
+ test("client does not send 'SPARK_CONF_DIR' env var by default") {
+ val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_CONF_DIR" -> "1")
+ val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables)
+ assert(filteredVariables == Map("SPARK_VAR" -> "1"))
+ }
+
test("client includes mesos env vars") {
val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1")
val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables)
diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
index eeffc36070b4..2849a10a2c81 100644
--- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.security
+import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.security.Credentials
@@ -110,7 +111,64 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers {
creds.getAllTokens.size should be (0)
}
+ test("SPARK-23209: obtain tokens when Hive classes are not available") {
+ // This test needs a custom class loader to hide Hive classes which are in the classpath.
+ // Because the manager code loads the Hive provider directly instead of using reflection, we
+ // need to drive the test through the custom class loader so a new copy that cannot find
+ // Hive classes is loaded.
+ val currentLoader = Thread.currentThread().getContextClassLoader()
+ val noHive = new ClassLoader() {
+ override def loadClass(name: String, resolve: Boolean): Class[_] = {
+ if (name.startsWith("org.apache.hive") || name.startsWith("org.apache.hadoop.hive")) {
+ throw new ClassNotFoundException(name)
+ }
+
+ if (name.startsWith("java") || name.startsWith("scala")) {
+ currentLoader.loadClass(name)
+ } else {
+ val classFileName = name.replaceAll("\\.", "/") + ".class"
+ val in = currentLoader.getResourceAsStream(classFileName)
+ if (in != null) {
+ val bytes = IOUtils.toByteArray(in)
+ defineClass(name, bytes, 0, bytes.length)
+ } else {
+ throw new ClassNotFoundException(name)
+ }
+ }
+ }
+ }
+
+ try {
+ Thread.currentThread().setContextClassLoader(noHive)
+ val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$"))
+ test.getMethod("runTest").invoke(null)
+ } finally {
+ Thread.currentThread().setContextClassLoader(currentLoader)
+ }
+ }
+
private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = {
Set(FileSystem.get(hadoopConf))
}
}
+
+/** Test code for SPARK-23209 to avoid using too much reflection above. */
+private object NoHiveTest extends Matchers {
+
+ def runTest(): Unit = {
+ try {
+ val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(),
+ _ => Set())
+ manager.getServiceDelegationTokenProvider("hive") should be (None)
+ } catch {
+ case e: Throwable =>
+ // Throw a better exception in case the test fails, since there may be a lot of nesting.
+ var cause = e
+ while (cause.getCause() != null) {
+ cause = cause.getCause()
+ }
+ throw cause
+ }
+ }
+
+}
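For readers puzzling over the `stripSuffix("$")` plus `invoke(null)` pattern in the test above, here is a minimal standalone sketch (hypothetical object name): a top-level Scala object without a companion class compiles to a `Foo$` class holding the singleton plus a mirror class `Foo` with static forwarders, so stripping the trailing "$" yields a class whose methods can be invoked statically through any class loader.

object ForwarderSketch {
  def runTest(): Unit = println("invoked via static forwarder")

  def main(args: Array[String]): Unit = {
    // "ForwarderSketch$" -> "ForwarderSketch", the mirror class with static forwarders
    val mirrorClassName = ForwarderSketch.getClass.getName.stripSuffix("$")
    Class.forName(mirrorClassName).getMethod("runTest").invoke(null)
  }
}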
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 105a178f2d94..1a7bebe2c53c 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.Map
import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
- val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+ val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+ assert(failReason.isInstanceOf[ExceptionFailure])
+ val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+ verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+ assert(exceptionCaptor.getAllValues.size === 1)
+ assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+ }
+
+ test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+ // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+ // as the fetch failure could easily be caused by interrupting the thread.
+ val (failReason, _) = testFetchFailureHandling(false)
+ assert(failReason.isInstanceOf[TaskKilled])
+ }
+
+ /**
+ * Helper for testing some cases where a FetchFailure should *not* get sent back, because it is
+ * superseded by another error, either an OOM or intentionally killing a task.
+ * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+ * FetchFailure
+ */
+ private def testFetchFailureHandling(
+ oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
+ // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+ // does not represent a real fetch failure.
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
- // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
- // the fetch failure as a false positive, and just do normal OOM handling.
+ // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+ // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
val inputRDD = new FetchFailureThrowingRDD(sc)
- val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+ if (!oom) {
+ // we are trying to set up a case where a task is killed after a fetch failure -- this
+ // is just a helper to coordinate between the task thread and this thread that will
+ // kill the task
+ ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+ }
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)
- val (failReason, uncaughtExceptionHandler) =
- runTaskGetFailReasonAndExceptionHandler(taskDescription)
- // make sure the task failure just looks like a OOM, not a fetch failure
- assert(failReason.isInstanceOf[ExceptionFailure])
- val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
- verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
- assert(exceptionCaptor.getAllValues.size === 1)
- assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
- }
+ runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+ }
test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
- runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+ runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
}
private def runTaskGetFailReasonAndExceptionHandler(
- taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+ taskDescription: TaskDescription,
+ killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
+ val timedOut = new AtomicBoolean(false)
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
+ if (killTask) {
+ val killingThread = new Thread("kill-task") {
+ override def run(): Unit = {
+ // wait to kill the task until it has thrown a fetch failure
+ if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+ // now we can kill the task
+ executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+ } else {
+ timedOut.set(true)
+ }
+ }
+ }
+ killingThread.start()
+ }
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
+ assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
} finally {
if (executor != null) {
executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+ val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
orderedMock.verify(mockBackend)
- .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+ .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
class FetchFailureHidingRDD(
sc: SparkContext,
val input: FetchFailureThrowingRDD,
- throwOOM: Boolean) extends RDD[Int](input) {
+ throwOOM: Boolean,
+ interrupt: Boolean) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
case t: Throwable =>
if (throwOOM) {
throw new OutOfMemoryError("OOM while handling another exception")
+ } else if (interrupt) {
+ // make sure our test is set up correctly
+ assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+ // signal our test is ready for the task to get killed
+ ExecutorSuiteHelper.latches.latch1.countDown()
+ // then wait for another thread in the test to kill the task -- this latch
+ // is never actually decremented, we just wait to get killed.
+ ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+ throw new IllegalStateException("timed out waiting to be interrupted")
} else {
throw new RuntimeException("User Exception that hides the original exception", t)
}
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
@volatile var testFailedReason: TaskFailedReason = _
}
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+ var latches: ExecutorSuiteHelper = null
+}
+
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
def writeExternal(out: ObjectOutput): Unit = {}
def readExternal(in: ObjectInput): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index 3b798e36b049..2107559572d7 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.network.util.ByteArrayWritableChannel
import org.apache.spark.util.io.ChunkedByteBuffer
-class ChunkedByteBufferSuite extends SparkFunSuite {
+class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
test("no chunks") {
val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
@@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
assert(chunkedByteBuffer.getChunks().head.position() === 0)
}
+ test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") {
+ try {
+ sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L)
+ val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024)))
+ val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt)
+ chunkedByteBuffer.writeFully(byteArrayWritableChannel)
+ assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size)
+ } finally {
+ sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE)
+ }
+ }
+
test("toArray()") {
val empty = ByteBuffer.wrap(Array.empty[Byte])
val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
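The SPARK-24107 test above writes a 40 MB buffer while the write chunk size is capped at 32 MB. As a rough illustration of bounded-slice writes (a sketch under that idea, not ChunkedByteBuffer.writeFully itself):

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}

object ChunkedWriteSketch {
  // Write the buffer in slices of at most chunkSize bytes by temporarily lowering its limit.
  def writeFully(channel: WritableByteChannel, buf: ByteBuffer, chunkSize: Int): Unit = {
    while (buf.hasRemaining) {
      val originalLimit = buf.limit()
      buf.limit(math.min(buf.position() + chunkSize, originalLimit))
      while (buf.hasRemaining) channel.write(buf)
      buf.limit(originalLimit)
    }
  }

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    writeFully(Channels.newChannel(out), ByteBuffer.wrap(new Array[Byte](40 * 1024)), 8 * 1024)
    assert(out.size() == 40 * 1024)  // every byte arrives despite the chunk cap
  }
}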
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 362cd861cc24..dcf89e4f75ac 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -29,6 +29,7 @@ object MemoryTestingUtils {
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
new TaskContextImpl(
stageId = 0,
+ stageAttemptNumber = 0,
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
diff --git a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala
new file mode 100644
index 000000000000..d7e4b9166fa0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
+
+import org.scalatest.concurrent._
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
+import org.apache.spark.storage.{BlockId, StorageLevel}
+
+class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits {
+
+ implicit val defaultSignaler: Signaler = ThreadSignaler
+
+ test("fetchBlockSync should not hang when BlockFetchingListener.onBlockFetchSuccess fails") {
+ // Create a mocked `BlockTransferService` to call `BlockFetchingListener.onBlockFetchSuccess`
+ // with a bad `ManagedBuffer` which will trigger an exception in `onBlockFetchSuccess`.
+ val blockTransferService = new BlockTransferService {
+ override def init(blockDataManager: BlockDataManager): Unit = {}
+
+ override def close(): Unit = {}
+
+ override def port: Int = 0
+
+ override def hostName: String = "localhost-unused"
+
+ override def fetchBlocks(
+ host: String,
+ port: Int,
+ execId: String,
+ blockIds: Array[String],
+ listener: BlockFetchingListener,
+ tempFileManager: DownloadFileManager): Unit = {
+ // Notify BlockFetchingListener with a bad ManagedBuffer asynchronously
+ new Thread() {
+ override def run(): Unit = {
+ // This is a bad buffer to trigger `IllegalArgumentException` in
+ // `BlockFetchingListener.onBlockFetchSuccess`. The real issue we hit is that
+ // `ByteBuffer.allocate` throws `OutOfMemoryError`, but we cannot make that happen in
+ // a test. Instead, we use a negative size value to make `ByteBuffer.allocate` fail,
+ // and this should trigger the same code path as `OutOfMemoryError`.
+ val badBuffer = new ManagedBuffer {
+ override def size(): Long = -1
+
+ override def nioByteBuffer(): ByteBuffer = null
+
+ override def createInputStream(): InputStream = null
+
+ override def retain(): ManagedBuffer = this
+
+ override def release(): ManagedBuffer = this
+
+ override def convertToNetty(): AnyRef = null
+ }
+ listener.onBlockFetchSuccess("block-id-unused", badBuffer)
+ }
+ }.start()
+ }
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ execId: String,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel,
+ classTag: ClassTag[_]): Future[Unit] = {
+ // This method is unused in this test
+ throw new UnsupportedOperationException("uploadBlock")
+ }
+ }
+
+ val e = intercept[SparkException] {
+ failAfter(10.seconds) {
+ blockTransferService.fetchBlockSync(
+ "localhost-unused", 0, "exec-id-unused", "block-id-unused", null)
+ }
+ }
+ assert(e.getCause.isInstanceOf[IllegalArgumentException])
+ }
+}
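A minimal sketch of the failure mode the mocked ManagedBuffer above relies on: ByteBuffer.allocate rejects a negative capacity with an IllegalArgumentException, which stands in for the OutOfMemoryError that is hard to provoke deterministically in a test (object name hypothetical):

import java.nio.ByteBuffer
import scala.util.Try

object NegativeAllocateSketch {
  def main(args: Array[String]): Unit = {
    val attempt = Try(ByteBuffer.allocate(-1))
    assert(attempt.isFailure && attempt.failed.get.isInstanceOf[IllegalArgumentException])
  }
}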
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index a39e0469272f..47af5c3320dd 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -322,8 +322,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
}
// See SPARK-22465
- test("cogroup between multiple RDD" +
- " with number of partitions similar in order of magnitude") {
+ test("cogroup between multiple RDD with number of partitions similar in order of magnitude") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
val rdd2 = sc
.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
@@ -332,6 +331,48 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(joined.getNumPartitions == rdd2.getNumPartitions)
}
+ test("cogroup between multiple RDD when defaultParallelism is set without proper partitioner") {
+ assert(!sc.conf.contains("spark.default.parallelism"))
+ try {
+ sc.conf.set("spark.default.parallelism", "4")
+ val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
+ val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)), 10)
+ val joined = rdd1.cogroup(rdd2)
+ assert(joined.getNumPartitions == sc.defaultParallelism)
+ } finally {
+ sc.conf.remove("spark.default.parallelism")
+ }
+ }
+
+ test("cogroup between multiple RDD when defaultParallelism is set with proper partitioner") {
+ assert(!sc.conf.contains("spark.default.parallelism"))
+ try {
+ sc.conf.set("spark.default.parallelism", "4")
+ val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
+ val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ .partitionBy(new HashPartitioner(10))
+ val joined = rdd1.cogroup(rdd2)
+ assert(joined.getNumPartitions == rdd2.getNumPartitions)
+ } finally {
+ sc.conf.remove("spark.default.parallelism")
+ }
+ }
+
+ test("cogroup between multiple RDD when defaultParallelism is set; with huge number of " +
+ "partitions in upstream RDDs") {
+ assert(!sc.conf.contains("spark.default.parallelism"))
+ try {
+ sc.conf.set("spark.default.parallelism", "4")
+ val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000)
+ val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ .partitionBy(new HashPartitioner(10))
+ val joined = rdd1.cogroup(rdd2)
+ assert(joined.getNumPartitions == rdd2.getNumPartitions)
+ } finally {
+ sc.conf.remove("spark.default.parallelism")
+ }
+ }
+
test("rightOuterJoin") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
index d3bbfd11d406..fe22d70850c7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.internal.config
class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{
val badHost = "host-0"
- val duration = Duration(10, SECONDS)
/**
* This backend just always fails if the task is executed on a bad host, but otherwise succeeds
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index cd1b7a9e5ab1..00867ef1308a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
- when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called"))
+ when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
@@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures)
- verify(allocationClientMock, never).killExecutors(any(), any(), any())
+ verify(allocationClientMock, never).killExecutors(any(), any(), any(), any())
verify(allocationClientMock, never).killExecutorsOnHost(any())
// Enable auto-kill. Blacklist an executor and make sure killExecutors is called.
@@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures)
- verify(allocationClientMock).killExecutors(Seq("1"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("1"), false, false, true)
val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1)
// Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole
@@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures)
- verify(allocationClientMock).killExecutors(Seq("2"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("2"), false, false, true)
verify(allocationClientMock).killExecutorsOnHost("hostA")
}
test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
- when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called"))
+ when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
@@ -571,16 +571,19 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
conf.set(config.BLACKLIST_KILL_ENABLED, false)
blacklist.updateBlacklistForFetchFailure("hostA", exec = "1")
- verify(allocationClientMock, never).killExecutors(any(), any(), any())
+ verify(allocationClientMock, never).killExecutors(any(), any(), any(), any())
verify(allocationClientMock, never).killExecutorsOnHost(any())
+ assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+ assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))
+
// Enable auto-kill. Blacklist an executor and make sure killExecutors is called.
conf.set(config.BLACKLIST_KILL_ENABLED, true)
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)
clock.advance(1000)
blacklist.updateBlacklistForFetchFailure("hostA", exec = "1")
- verify(allocationClientMock).killExecutors(Seq("1"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("1"), false, false, true)
verify(allocationClientMock, never).killExecutorsOnHost(any())
assert(blacklist.executorIdToBlacklistStatus.contains("1"))
@@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty)
+ assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+ assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))
// Enable external shuffle service to see if all the executors on this node will be killed.
conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d812b5bd92c1..4bfe618f9bd2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -56,6 +56,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
}
+class MyCheckpointRDD(
+ sc: SparkContext,
+ numPartitions: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil,
+ @(transient @param) tracker: MapOutputTrackerMaster = null,
+ indeterminate: Boolean = false)
+ extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+ // Allow doCheckpoint() on this RDD.
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ Iterator.empty
+}
+
/**
* An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
* preferredLocations (if any) that are passed to them. They are deliberately not executable
@@ -70,7 +84,8 @@ class MyRDD(
numPartitions: Int,
dependencies: List[Dependency[_]],
locations: Seq[Seq[String]] = Nil,
- @(transient @param) tracker: MapOutputTrackerMaster = null)
+ @(transient @param) tracker: MapOutputTrackerMaster = null,
+ indeterminate: Boolean = false)
extends RDD[(Int, Int)](sc, dependencies) with Serializable {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
@@ -80,6 +95,10 @@ class MyRDD(
override def index: Int = i
}).toArray
+ override protected def getOutputDeterministicLevel = {
+ if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+ }
+
override def getPreferredLocations(partition: Partition): Seq[String] = {
if (locations.isDefinedAt(partition.index)) {
locations(partition.index)
@@ -1766,6 +1785,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(sc.parallelize(1 to 10, 2).count() === 10)
}
+ test("misbehaved accumulator should not impact other accumulators") {
+ val bad = new LongAccumulator {
+ override def merge(other: AccumulatorV2[java.lang.Long, java.lang.Long]): Unit = {
+ throw new DAGSchedulerSuiteDummyException
+ }
+ }
+ sc.register(bad, "bad")
+ val good = sc.longAccumulator("good")
+
+ sc.parallelize(1 to 10, 2).foreach { item =>
+ bad.add(1)
+ good.add(1)
+ }
+
+ // This is to ensure the `bad` accumulator did fail to update its value
+ assert(bad.value == 0L)
+ // Should be able to update the "good" accumulator
+ assert(good.value == 10L)
+ }
+
/**
* The job will be failed on first task throwing a DAGSchedulerSuiteDummyException.
* Any subsequent task WILL throw a legitimate java.lang.UnsupportedOperationException.
@@ -2146,6 +2185,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assertDataStructuresEmpty()
}
+ test("Trigger mapstage's job listener in submitMissingTasks") {
+ val rdd1 = new MyRDD(sc, 2, Nil)
+ val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2))
+ val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker)
+ val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2))
+
+ val listener1 = new SimpleListener
+ val listener2 = new SimpleListener
+
+ submitMapStage(dep1, listener1)
+ submitMapStage(dep2, listener2)
+
+ // Complete the stage0.
+ assert(taskSets(0).stageId === 0)
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", rdd1.partitions.length)),
+ (Success, makeMapStatus("hostB", rdd1.partitions.length))))
+ assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(listener1.results.size === 1)
+
+ // When attempting stage1, trigger a fetch failure.
+ assert(taskSets(1).stageId === 1)
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostC", rdd2.partitions.length)),
+ (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null)))
+ scheduler.resubmitFailedStages()
+ // Stage1 listener should not have a result yet
+ assert(listener2.results.size === 0)
+
+ // Speculative task succeeded in stage1.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1),
+ Success,
+ makeMapStatus("hostD", rdd2.partitions.length)))
+ // stage1 listener still should not have a result, though there's no missing partitions
+ // in it. Because stage1 has been failed and is not inside `runningStages` at this moment.
+ assert(listener2.results.size === 0)
+
+ // Stage0 should now be running as task set 2; make its task succeed
+ assert(taskSets(2).stageId === 0)
+ complete(taskSets(2), Seq(
+ (Success, makeMapStatus("hostC", rdd2.partitions.length))))
+ assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
+ Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+
+ // After stage0 finishes, stage1 will be submitted and found to have no missing partitions.
+ // Then the listener is triggered.
+ assert(listener2.results.size === 1)
+ assertDataStructuresEmpty()
+ }
+
/**
* In this test, we run a map stage where one of the executors fails but we still receive a
* "zombie" complete message from that executor. We want to make sure the stage is not reported
@@ -2343,7 +2434,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2)))
- // Both tasks in rddB should be resubmitted, because none of them has succeeded truely.
+ // task(stageId=1, stageAttemptId=1, partitionId=1) should be marked completed when
+ // task(stageId=1, stageAttemptId=0, partitionId=1) finished
+ // ideally we would verify that, but there is no way to get into the task scheduler to verify it
+
+ // Both tasks in rddB should be resubmitted, because none of them has succeeded truly.
// Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully.
// Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt
// is still running.
@@ -2352,19 +2447,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(3).tasks(0), Success, makeMapStatus("hostB", 2)))
- // There should be no new attempt of stage submitted,
- // because task(stageId=1, stageAttempt=1, partitionId=1) is still running in
- // the current attempt (and hasn't completed successfully in any earlier attempts).
- assert(taskSets.size === 4)
+ // At this point there should be no active task set for stageId=1 and we need
+ // to resubmit because the output from (stageId=1, stageAttemptId=0, partitionId=1)
+ // was ignored due to executor failure
+ assert(taskSets.size === 5)
+ assert(taskSets(4).stageId === 1 && taskSets(4).stageAttemptId === 2
+ && taskSets(4).tasks.size === 1)
- // Complete task(stageId=1, stageAttempt=1, partitionId=1) successfully.
+ // Complete task(stageId=1, stageAttempt=2, partitionId=1) successfully.
runEvent(makeCompletionEvent(
- taskSets(3).tasks(1), Success, makeMapStatus("hostB", 2)))
+ taskSets(4).tasks(0), Success, makeMapStatus("hostB", 2)))
// Now the ResultStage should be submitted, because all of the tasks of rddB have
// completed successfully on alive executors.
- assert(taskSets.size === 5 && taskSets(4).tasks(0).isInstanceOf[ResultTask[_, _]])
- complete(taskSets(4), Seq(
+ assert(taskSets.size === 6 && taskSets(5).tasks(0).isInstanceOf[ResultTask[_, _]])
+ complete(taskSets(5), Seq(
(Success, 1),
(Success, 1)))
}
@@ -2399,6 +2496,152 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
}
+ test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+ val shuffleId1 = shuffleDep1.shuffleId
+ val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
+ val shuffleId2 = shuffleDep2.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1))
+
+ // Finish the first shuffle map stage.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+ // Finish the second shuffle map stage.
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostC", 2)),
+ (Success, makeMapStatus("hostD", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+ // The first task of the final stage failed with fetch failure
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
+ null))
+
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.length == 2)
+ // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+ assert(failedStages.collect {
+ case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
+ }.head.findMissingPartitions() == Seq(0))
+ // The result stage is still waiting for its 2 tasks to complete
+ assert(failedStages.collect {
+ case stage: ResultStage => stage
+ }.head.findMissingPartitions() == Seq(0, 1))
+
+ scheduler.resubmitFailedStages()
+
+ // The first task of the `shuffleMapRdd2` failed with fetch failure
+ runEvent(makeCompletionEvent(
+ taskSets(3).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"),
+ null))
+
+ // The job should fail because Spark can't rollback the shuffle map stage.
+ assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+ }
+
+ private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
+ val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+ val shuffleId = shuffleDep.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1))
+
+ completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+ // Finish the first task of the result stage
+ runEvent(makeCompletionEvent(
+ taskSets.last.tasks(0), Success, 42,
+ Seq.empty, createFakeTaskInfoWithId(0)))
+
+ // Fail the second task with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets.last.tasks(1),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null))
+
+ // The job should fail because Spark can't rollback the result stage.
+ assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+ }
+
+ test("SPARK-23207: cannot rollback a result stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
+ assertResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") {
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+ shuffleMapRdd.localCheckpoint()
+ shuffleMapRdd.doCheckpoint()
+ assertResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") {
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+ shuffleMapRdd.localCheckpoint()
+ assertResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = {
+ val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+ val shuffleId = shuffleDep.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1))
+
+ completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+ // Finish the first task of the result stage
+ runEvent(makeCompletionEvent(
+ taskSets.last.tasks(0), Success, 42,
+ Seq.empty, createFakeTaskInfoWithId(0)))
+
+ // Fail the second task with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets.last.tasks(1),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null))
+
+ assert(failure == null, "job should not fail")
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.length == 2)
+ // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+ assert(failedStages.collect {
+ case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId => stage
+ }.head.findMissingPartitions() == Seq(0))
+ // The first task of the result stage remains completed.
+ assert(failedStages.collect {
+ case stage: ResultStage => stage
+ }.head.findMissingPartitions() == Seq(1))
+ }
+
+ test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") {
+ sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+ shuffleMapRdd.checkpoint()
+ shuffleMapRdd.doCheckpoint()
+ assertResultStageNotRollbacked(shuffleMapRdd)
+ }
+
+ test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") {
+ sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+ shuffleMapRdd.checkpoint()
+ assertResultStageFailToRollback(shuffleMapRdd)
+ }
+
/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
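The SPARK-23207 tests above distinguish reliable checkpoints, which let Spark re-read saved data instead of recomputing an indeterminate map stage, from local checkpoints, which cannot survive a lost executor. A user-level sketch of the reliable-checkpoint pattern, assuming a local-mode context and a temporary checkpoint directory (names illustrative only):

import java.nio.file.Files
import org.apache.spark.{SparkConf, SparkContext}

object ReliableCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ckpt-sketch"))
    sc.setCheckpointDir(Files.createTempDirectory("ckpt").toString)
    val pairs = sc.parallelize(1 to 100, 4).map(i => (i % 10, i))
    pairs.checkpoint()   // mark for reliable checkpointing before the first job runs
    pairs.count()        // materialization writes the checkpoint data out
    sc.stop()
  }
}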
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index 03b190390249..158c9eb75f2b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark._
import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.rdd.{FakeOutputCommitter, RDD}
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{ThreadUtils, Utils}
/**
@@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
test("Job should not complete if all commits are denied") {
// Create a mock OutputCommitCoordinator that denies all attempts to commit
doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
- Matchers.any(), Matchers.any(), Matchers.any())
+ Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any())
val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
def resultHandler(x: Int, y: Unit): Unit = {}
val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
@@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") {
val stage: Int = 1
+ val stageAttempt: Int = 1
val partition: Int = 2
val authorizedCommitter: Int = 3
val nonAuthorizedCommitter: Int = 100
outputCommitCoordinator.stageStart(stage, maxPartitionId = 2)
- assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter))
- assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter))
+ assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter))
+ assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
+ nonAuthorizedCommitter))
// The non-authorized committer fails
- outputCommitCoordinator.taskCompleted(
- stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test"))
+ outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
+ attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test"))
// New tasks should still not be able to commit because the authorized committer has not failed
- assert(
- !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1))
+ assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
+ nonAuthorizedCommitter + 1))
// The authorized committer now fails, clearing the lock
- outputCommitCoordinator.taskCompleted(
- stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test"))
+ outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
+ attemptNumber = authorizedCommitter, reason = TaskKilled("test"))
// A new task should now be allowed to become the authorized committer
- assert(
- outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2))
+ assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
+ nonAuthorizedCommitter + 2))
// There can only be one authorized committer
- assert(
- !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3))
- }
-
- test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") {
- val rdd = sc.parallelize(Seq(1), 1)
- sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _,
- 0 until rdd.partitions.size)
+ assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
+ nonAuthorizedCommitter + 3))
}
test("SPARK-19631: Do not allow failed attempts to be authorized for committing") {
val stage: Int = 1
+ val stageAttempt: Int = 1
val partition: Int = 1
val failedAttempt: Int = 0
outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
- outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt,
+ outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
+ attemptNumber = failedAttempt,
reason = ExecutorLostFailure("0", exitCausedByApp = true, None))
- assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt))
- assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1))
+ assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt))
+ assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1))
+ }
+
+ test("SPARK-24589: Differentiate tasks from different stage attempts") {
+ var stage = 1
+ val taskAttempt = 1
+ val partition = 1
+
+ outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
+ assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
+ assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt))
+
+ // Fail the task in the first attempt, the task in the second attempt should succeed.
+ stage += 1
+ outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
+ outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt,
+ ExecutorLostFailure("0", exitCausedByApp = true, None))
+ assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
+ assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt))
+
+ // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit,
+ // then fail the 1st attempt and make sure the 4th one can commit again.
+ stage += 1
+ outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
+ assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
+ outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt,
+ ExecutorLostFailure("0", exitCausedByApp = true, None))
+ assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt))
+ outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt,
+ ExecutorLostFailure("0", exitCausedByApp = true, None))
+ assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt))
+ }
+
+ test("SPARK-24589: Make sure stage state is cleaned up") {
+ // Normal application without stage failures.
+ sc.parallelize(1 to 100, 100)
+ .map { i => (i % 10, i) }
+ .reduceByKey(_ + _)
+ .collect()
+
+ assert(sc.dagScheduler.outputCommitCoordinator.isEmpty)
+
+ // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing
+ // stage so that we can check the state of the output committer.
+ val retriedStage = sc.parallelize(1 to 100, 10)
+ .map { i => (i % 10, i) }
+ .reduceByKey { case (_, _) =>
+ val ctx = TaskContext.get()
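+ // A FetchFailedException on the first stage attempt forces the whole stage to be retried;
+ // the retried attempt records its stage ID so the test can inspect the coordinator state.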
+ if (ctx.stageAttemptNumber() == 0) {
+ throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1,
+ new Exception("Failure for test."))
+ } else {
+ ctx.stageId()
+ }
+ }
+ .collect()
+ .map { case (k, v) => v }
+ .toSet
+
+ assert(retriedStage.size === 1)
+ assert(sc.dagScheduler.outputCommitCoordinator.isEmpty)
+ verify(sc.env.outputCommitCoordinator, times(2))
+ .stageStart(Matchers.eq(retriedStage.head), Matchers.any())
+ verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head))
}
}
@@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) {
if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
}
- // Receiver should be idempotent for AskPermissionToCommitOutput
- def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = {
- val ctx = TaskContext.get()
- val canCommit1 = SparkEnv.get.outputCommitCoordinator
- .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber())
- val canCommit2 = SparkEnv.get.outputCommitCoordinator
- .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber())
- assert(canCommit1 && canCommit2)
- }
-
private def runCommitWithProvidedCommitter(
ctx: TaskContext,
iter: Iterator[Int],
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 75ea409e16b4..c783f4665afd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -51,6 +51,9 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
var taskScheduler: TestTaskScheduler = null
var scheduler: DAGScheduler = null
var backend: T = _
+ // Even though the tests aren't doing much, occasionally we see flakiness from pauses over
+ // a second (probably from GC?) so we leave a long timeout here
+ val duration = Duration(10, SECONDS)
override def beforeEach(): Unit = {
if (taskScheduler != null) {
@@ -398,7 +401,8 @@ private[spark] abstract class MockBackend(
// get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual
// tests from introducing a race if they need it.
val newTasks = newTaskDescriptions.map { taskDescription =>
- val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet
+ val taskSet =
+ Option(taskScheduler.taskIdToTaskSetManager.get(taskDescription.taskId).taskSet).get
val task = taskSet.tasks(taskDescription.index)
(taskDescription, task)
}
@@ -536,7 +540,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
}
withBackend(runBackend _) {
val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
- val duration = Duration(1, SECONDS)
awaitJobTermination(jobFuture, duration)
}
assert(results === (0 until 10).map { _ -> 42 }.toMap)
@@ -589,7 +592,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
}
withBackend(runBackend _) {
val jobFuture = submit(d, (0 until 30).toArray)
- val duration = Duration(1, SECONDS)
awaitJobTermination(jobFuture, duration)
}
assert(results === (0 until 30).map { idx => idx -> (4321 + idx) }.toMap)
@@ -631,7 +633,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
}
withBackend(runBackend _) {
val jobFuture = submit(shuffledRdd, (0 until 10).toArray)
- val duration = Duration(1, SECONDS)
awaitJobTermination(jobFuture, duration)
}
assertDataStructuresEmpty()
@@ -646,7 +647,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
}
withBackend(runBackend _) {
val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
- val duration = Duration(1, SECONDS)
awaitJobTermination(jobFuture, duration)
assert(failure.getMessage.contains("test task failure"))
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 1beb36afa95f..6ffd1e84f7ad 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.Semaphore
import scala.collection.JavaConverters._
@@ -48,7 +49,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount
}
- private def queueSize(bus: LiveListenerBus): Int = {
+ private def sharedQueueSize(bus: LiveListenerBus): Int = {
bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue()
.asInstanceOf[Int]
}
@@ -73,12 +74,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val conf = new SparkConf()
val counter = new BasicJobCounter
val bus = new LiveListenerBus(conf)
- bus.addToSharedQueue(counter)
// Metrics are initially empty.
assert(bus.metrics.numEventsPosted.getCount === 0)
assert(numDroppedEvents(bus) === 0)
- assert(queueSize(bus) === 0)
+ assert(bus.queuedEvents.size === 0)
assert(eventProcessingTimeCount(bus) === 0)
// Post five events:
@@ -87,7 +87,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// Five messages should be marked as received and queued, but no messages should be posted to
// listeners yet because the listener bus hasn't been started.
assert(bus.metrics.numEventsPosted.getCount === 5)
- assert(queueSize(bus) === 5)
+ assert(bus.queuedEvents.size === 5)
+
+ // Add the counter to the bus after messages have been queued for later delivery.
+ bus.addToSharedQueue(counter)
assert(counter.count === 0)
// Starting listener bus should flush all buffered events
@@ -95,9 +98,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
Mockito.verify(mockMetricsSystem).registerSource(bus.metrics)
bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(counter.count === 5)
- assert(queueSize(bus) === 0)
+ assert(sharedQueueSize(bus) === 0)
assert(eventProcessingTimeCount(bus) === 5)
+ // After the bus is started, there should be no more queued events.
+ assert(bus.queuedEvents === null)
+
// After listener bus has stopped, posting events should not increment counter
bus.stop()
(1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
@@ -188,18 +194,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// Post a message to the listener bus and wait for processing to begin:
bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
listenerStarted.acquire()
- assert(queueSize(bus) === 0)
+ assert(sharedQueueSize(bus) === 0)
assert(numDroppedEvents(bus) === 0)
// If we post an additional message then it should remain in the queue because the listener is
// busy processing the first event:
bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
- assert(queueSize(bus) === 1)
+ assert(sharedQueueSize(bus) === 1)
assert(numDroppedEvents(bus) === 0)
// The queue is now full, so any additional events posted to the listener will be dropped:
bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
- assert(queueSize(bus) === 1)
+ assert(sharedQueueSize(bus) === 1)
assert(numDroppedEvents(bus) === 1)
// Allow the remaining events to be processed so we can stop the listener bus:
@@ -289,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
- // just to make sure some of the tasks take a noticeable amount of time
+ // just to make sure some of the tasks and their deserialization take a noticeable
+ // amount of time
+ val slowDeserializable = new SlowDeserializable
val w = { i: Int =>
if (i == 0) {
Thread.sleep(100)
+ slowDeserializable.use()
}
i
}
@@ -480,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
assert(bus.findListenersByClass[BasicJobCounter]().isEmpty)
}
+ Seq(true, false).foreach { throwInterruptedException =>
+ val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted"
+ test(s"interrupt within listener is handled correctly: $suffix") {
+ val conf = new SparkConf(false)
+ .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5)
+ val bus = new LiveListenerBus(conf)
+ val counter1 = new BasicJobCounter()
+ val counter2 = new BasicJobCounter()
+ val interruptingListener1 = new InterruptingListener(throwInterruptedException)
+ val interruptingListener2 = new InterruptingListener(throwInterruptedException)
+ bus.addToSharedQueue(counter1)
+ bus.addToSharedQueue(interruptingListener1)
+ bus.addToStatusQueue(counter2)
+ bus.addToEventLogQueue(interruptingListener2)
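+ // The event log queue holds only an interrupting listener, so once that listener is removed
+ // the queue is empty and gets removed too; the shared queue keeps counter1 and stays active.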
+ assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE))
+ assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
+ assert(bus.findListenersByClass[InterruptingListener]().size === 2)
+
+ bus.start(mockSparkContext, mockMetricsSystem)
+
+ // after we post one event, both interrupting listeners should get removed, and the
+ // event log queue should be removed
+ bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
+ bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE))
+ assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
+ assert(bus.findListenersByClass[InterruptingListener]().size === 0)
+ assert(counter1.count === 1)
+ assert(counter2.count === 1)
+
+ // posting more events should be fine, they'll just get processed by the remaining queues.
+ (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
+ bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(counter1.count === 6)
+ assert(counter2.count === 6)
+
+ // Make sure stopping works -- this requires putting a poison pill in all active queues, which
+ // would fail if our interrupted queue was still active, as its queue would be full.
+ bus.stop()
+ }
+ }
+
/**
* Assert that the given list of numbers has an average that is greater than zero.
*/
@@ -538,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception }
}
+ /**
+ * A simple listener that interrupts on job end.
+ */
+ private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener {
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+ if (throwInterruptedException) {
+ throw new InterruptedException("got interrupted")
+ } else {
+ Thread.currentThread().interrupt()
+ }
+ }
+ }
}
// These classes can't be declared inside of the SparkListenerSuite class because we don't want
@@ -578,3 +641,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar
case _ =>
}
}
+
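+// Referenced in the test above to make task deserialization take a noticeable amount of time.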
+private class SlowDeserializable extends Externalizable {
+
+ override def writeExternal(out: ObjectOutput): Unit = { }
+
+ override def readExternal(in: ObjectInput): Unit = Thread.sleep(1)
+
+ def use(): Unit = { }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index a1d9085fa085..aa9c36c0aaac 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
}
+ test("TaskContext.stageAttemptNumber getter") {
+ sc = new SparkContext("local[1,2]", "test")
+
+ // Check stageAttemptNumbers are 0 for initial stage
+ val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+ Seq(TaskContext.get().stageAttemptNumber()).iterator
+ }.collect()
+ assert(stageAttemptNumbers.toSet === Set(0))
+
+ // Check stageAttemptNumbers for a stage that is resubmitted when tasks throw FetchFailedException
+ val stageAttemptNumbersWithFailedStage =
+ sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+ val stageAttemptNumber = TaskContext.get().stageAttemptNumber()
+ if (stageAttemptNumber < 2) {
+ // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
+ // will only trigger task resubmission in the same stage.
+ throw new FetchFailedException(null, 0, 0, 0, "Fake")
+ }
+ Seq(stageAttemptNumber).iterator
+ }.collect()
+
+ assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
+ }
+
test("accumulators are updated on exception failures") {
// This means use 1 core and 4 max task failures
sc = new SparkContext("local[1,4]", "test")
@@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
// accumulator updates from it.
val taskMetrics = TaskMetrics.empty
val task = new Task[Int](0, 0, 0) {
- context = new TaskContextImpl(0, 0, 0L, 0,
+ context = new TaskContextImpl(0, 0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
@@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
// accumulator updates from it.
val taskMetrics = TaskMetrics.registered
val task = new Task[Int](0, 0, 0) {
- context = new TaskContextImpl(0, 0, 0L, 0,
+ context = new TaskContextImpl(0, 0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 6003899bb7be..6809d918495f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -62,7 +62,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
override def afterEach(): Unit = {
- super.afterEach()
if (taskScheduler != null) {
taskScheduler.stop()
taskScheduler = null
@@ -71,6 +70,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
dagScheduler.stop()
dagScheduler = null
}
+ super.afterEach()
}
def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = {
@@ -196,28 +196,39 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Even if one of the task sets has not-serializable tasks, the other task set should
// still be processed without error
taskScheduler.submitTasks(FakeTask.createTaskSet(1))
- taskScheduler.submitTasks(taskSet)
+ val taskSet2 = new TaskSet(
+ Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 1, 0, 0, null)
+ taskScheduler.submitTasks(taskSet2)
taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
}
- test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+ test("concurrent attempts for the same stage only have one active taskset") {
val taskScheduler = setupScheduler()
+ def isTasksetZombie(taskset: TaskSet): Boolean = {
+ taskScheduler.taskSetManagerForAttempt(taskset.stageId, taskset.stageAttemptId).get.isZombie
+ }
+
val attempt1 = FakeTask.createTaskSet(1, 0)
- val attempt2 = FakeTask.createTaskSet(1, 1)
taskScheduler.submitTasks(attempt1)
- intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+ // The first submitted taskset is active
+ assert(!isTasksetZombie(attempt1))
- // OK to submit multiple if previous attempts are all zombie
- taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
- .get.isZombie = true
+ val attempt2 = FakeTask.createTaskSet(1, 1)
taskScheduler.submitTasks(attempt2)
+ // The first submitted taskset is zombie now
+ assert(isTasksetZombie(attempt1))
+ // The newly submitted taskset is active
+ assert(!isTasksetZombie(attempt2))
+
val attempt3 = FakeTask.createTaskSet(1, 2)
- intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
- taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
- .get.isZombie = true
taskScheduler.submitTasks(attempt3)
- assert(!failedTaskSet)
+ // The first submitted taskset remains zombie
+ assert(isTasksetZombie(attempt1))
+ // The second submitted taskset is zombie now
+ assert(isTasksetZombie(attempt2))
+ // The newly submitted taskset is active
+ assert(!isTasksetZombie(attempt3))
}
test("don't schedule more tasks after a taskset is zombie") {
@@ -247,7 +258,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.submitTasks(attempt2)
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions3.length)
- val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+ val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId)).get
assert(mgr.taskSet.stageAttemptId === 1)
assert(!failedTaskSet)
}
@@ -285,7 +296,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(10 === taskDescriptions3.length)
taskDescriptions3.foreach { task =>
- val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+ val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(task.taskId)).get
assert(mgr.taskSet.stageAttemptId === 1)
}
assert(!failedTaskSet)
@@ -723,7 +734,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// only schedule one task because of locality
assert(taskDescs.size === 1)
- val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId).get
+ val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId)).get
assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.NODE_LOCAL, TaskLocality.ANY))
// we should know about both executors, even though we only scheduled tasks on one of them
assert(taskScheduler.getExecutorsAliveOnHost("host0") === Some(Set("executor0")))
@@ -917,4 +928,130 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.initialize(new FakeSchedulerBackend)
}
}
+
+ test("SPARK-23433/25250 Completions in zombie tasksets update status of non-zombie taskset") {
+ val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+ val valueSer = SparkEnv.get.serializer.newInstance()
+
+ def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
+ val indexInTsm = tsm.partitionToIndex(partition)
+ val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
+ }
+
+ // Submit a task set, have it fail with a fetch failure, and then re-submit the stage attempt,
+ // two times, so we have three TaskSetManagers (2 zombie, 1 active) for one stage. (For this
+ // to really happen, you'd need the previous stage to also get restarted, and then succeed,
+ // in between each attempt, but that happens outside what we're mocking here.)
+ val zombieAttempts = (0 until 2).map { stageAttempt =>
+ val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
+ taskScheduler.submitTasks(attempt)
+ val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
+ val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ taskScheduler.resourceOffers(offers)
+ assert(tsm.runningTasks === 10)
+ // fail attempt
+ tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
+ FetchFailed(null, 0, 0, 0, "fetch failed"))
+ // the attempt is a zombie, but the tasks are still running (this could be true even if
+ // we actively killed those tasks, as killing is best-effort)
+ assert(tsm.isZombie)
+ assert(tsm.runningTasks === 9)
+ tsm
+ }
+ // We've now got 2 zombie attempts, each with 9 tasks still running, and no active attempt
+ // exists in the taskScheduler yet.
+
+ // Finish partitions 1 and 2 by completing their tasks before a new attempt for the same stage
+ // is submitted. This is possible because new attempts are submitted and successful tasks are
+ // handled on two different threads, "dag-scheduler-event-loop" and "task-result-getter"
+ // respectively.
+ (0 until 2).foreach { i =>
+ completeTaskSuccessfully(zombieAttempts(i), i + 1)
+ assert(taskScheduler.stageIdToFinishedPartitions(0).contains(i + 1))
+ }
+
+ // Submit the 3rd attempt, still with 10 tasks. This happens because of the race between the
+ // "task-result-getter" and "dag-scheduler-event-loop" threads, which lets a TaskSet be
+ // submitted with tasks that have already completed. This time resources are insufficient, so
+ // not all tasks become active.
+ val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
+ taskScheduler.submitTasks(finalAttempt)
+ val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
+ // Though finalTsm gets submitted with 10 tasks, the call to taskScheduler.submitTasks should
+ // realize that 2 tasks have already completed, and mark them appropriately, so it won't launch
+ // any duplicate tasks later (SPARK-25250).
+ (0 until 2).map(_ + 1).foreach { partitionId =>
+ val index = finalTsm.partitionToIndex(partitionId)
+ assert(finalTsm.successful(index))
+ }
+
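+ // Only five single-core offers are made here, so at most five of the eight remaining tasks
+ // in the final attempt can be launched.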
+ val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
+ finalAttempt.tasks(task.index).partitionId
+ }.toSet
+ assert(finalTsm.runningTasks === 5)
+ assert(!finalTsm.isZombie)
+
+ // Continue to simulate late completions from our zombie tasksets (this time with an active
+ // attempt present in the taskScheduler), covering all of the pending partitions in our final
+ // attempt. This means we're only waiting on the tasks we've already launched.
+ val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
+ finalAttemptPendingPartitions.foreach { partition =>
+ completeTaskSuccessfully(zombieAttempts(0), partition)
+ assert(taskScheduler.stageIdToFinishedPartitions(0).contains(partition))
+ }
+
+ // If there is another resource offer, we shouldn't run anything. Though our final attempt
+ // used to have pending tasks, now those tasks have been completed by zombie attempts. The
+ // remaining tasks to compute are already active in the non-zombie attempt.
+ assert(
+ taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
+
+ val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
+
+ // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
+ // marked as zombie.
+ // for each of the remaining tasks, find the tasksets with an active copy of the task, and
+ // finish the task.
+ remainingTasks.foreach { partition =>
+ val tsm = if (partition == 0) {
+ // we failed this task on both zombie attempts, this one is only present in the latest
+ // taskset
+ finalTsm
+ } else {
+ // should be active in every taskset. We choose a zombie taskset just to make sure that
+ // we transition the active taskset correctly even if the final completion comes
+ // from a zombie.
+ zombieAttempts(partition % 2)
+ }
+ completeTaskSuccessfully(tsm, partition)
+ }
+
+ assert(finalTsm.isZombie)
+
+ // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
+ verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject())
+
+ // Finally, let's complete all the tasks. We simulate failures in attempt 1, but everything
+ // else succeeds, to make sure we get the right updates to the blacklist in all cases.
+ (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
+ val stageAttempt = tsm.taskSet.stageAttemptId
+ tsm.runningTasksSet.foreach { index =>
+ if (stageAttempt == 1) {
+ tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
+ } else {
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
+ }
+ }
+
+ // We update the blacklist for the stage attempts with all successful tasks. Even though
+ // some tasksets had failures, we still consider them all successful from a blacklisting
+ // perspective, as the failures weren't from a problem with the tasks themselves.
+ verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
+ }
+ assert(taskScheduler.stageIdToFinishedPartitions.isEmpty)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 2ce81ae27daf..d75c24500344 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -178,12 +178,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
}
override def afterEach(): Unit = {
- super.afterEach()
if (sched != null) {
sched.dagScheduler.stop()
sched.stop()
sched = null
}
+ super.afterEach()
}
@@ -1362,6 +1362,167 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
}
+ test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") {
+ val conf = new SparkConf().set("spark.speculation", "true")
+ sc = new SparkContext("local", "test", conf)
+ // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
+ sc.conf.set("spark.speculation.multiplier", "0.0")
+ sc.conf.set("spark.speculation.quantile", "0.1")
+ sc.conf.set("spark.speculation", "true")
+
+ sched = new FakeTaskScheduler(sc)
+ sched.initialize(new FakeSchedulerBackend())
+
+ val dagScheduler = new FakeDAGScheduler(sc, sched)
+ sched.setDAGScheduler(dagScheduler)
+
+ val taskSet1 = FakeTask.createTaskSet(10)
+ val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task =>
+ task.metrics.internalAccums
+ }
+
+ sched.submitTasks(taskSet1)
+ sched.resourceOffers(
+ (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+ val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get
+
+ // fail fetch
+ taskSetManager1.handleFailedTask(
+ taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED,
+ FetchFailed(null, 0, 0, 0, "fetch failed"))
+
+ assert(taskSetManager1.isZombie)
+ assert(taskSetManager1.runningTasks === 9)
+
+ val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1)
+ sched.submitTasks(taskSet2)
+ sched.resourceOffers(
+ (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+ // Complete the 2 tasks and leave 8 tasks running
+ for (id <- Set(0, 1)) {
+ taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+ assert(sched.endedTasks(id) === Success)
+ }
+
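+ // Tasks completed through the zombie taskSetManager1 should also be reflected in the active
+ // taskSetManager2, so its MedianHeap of successful task durations is non-empty when
+ // checkSpeculatableTasks runs (SPARK-24677).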
+ val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get
+ assert(!taskSetManager2.successfulTaskDurations.isEmpty())
+ taskSetManager2.checkSpeculatableTasks(0)
+ }
+
+ test("SPARK-24755 Executor loss can cause task to not be resubmitted") {
+ val conf = new SparkConf().set("spark.speculation", "true")
+ sc = new SparkContext("local", "test", conf)
+ // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
+ sc.conf.set("spark.speculation.multiplier", "0.0")
+
+ sc.conf.set("spark.speculation.quantile", "0.5")
+ sc.conf.set("spark.speculation", "true")
+
+ var killTaskCalled = false
+ sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
+ ("exec2", "host2"), ("exec3", "host3"))
+ sched.initialize(new FakeSchedulerBackend() {
+ override def killTask(
+ taskId: Long,
+ executorId: String,
+ interruptThread: Boolean,
+ reason: String): Unit = {
+ // Check that this is the only killTask event in this case, which is triggered by
+ // task 2.1 completing.
+ assert(taskId === 2)
+ assert(executorId === "exec3")
+ assert(interruptThread)
+ assert(reason === "another attempt succeeded")
+ killTaskCalled = true
+ }
+ })
+
+ // Keep track of the indexes of tasks that are resubmitted,
+ // so that the test can check that they are resubmitted correctly.
+ val resubmittedTasks = new mutable.HashSet[Int]
+ val dagScheduler = new FakeDAGScheduler(sc, sched) {
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Seq[AccumulatorV2[_, _]],
+ taskInfo: TaskInfo): Unit = {
+ super.taskEnded(task, reason, result, accumUpdates, taskInfo)
+ reason match {
+ case Resubmitted => resubmittedTasks += taskInfo.index
+ case _ =>
+ }
+ }
+ }
+ sched.dagScheduler.stop()
+ sched.setDAGScheduler(dagScheduler)
+
+ val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0,
+ Seq(TaskLocation("host1", "exec1")),
+ Seq(TaskLocation("host1", "exec1")),
+ Seq(TaskLocation("host3", "exec3")),
+ Seq(TaskLocation("host2", "exec2")))
+
+ val clock = new ManualClock()
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+ val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
+ task.metrics.internalAccums
+ }
+ // Offer resources for 4 tasks to start
+ for ((exec, host) <- Seq(
+ "exec1" -> "host1",
+ "exec1" -> "host1",
+ "exec3" -> "host3",
+ "exec2" -> "host2")) {
+ val taskOption = manager.resourceOffer(exec, host, NO_PREF)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === exec)
+ // Add an extra assert to make sure task 2.0 is running on exec3
+ if (task.index == 2) {
+ assert(task.attemptNumber === 0)
+ assert(task.executorId === "exec3")
+ }
+ }
+ assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+ clock.advance(1)
+ // Complete the 2 tasks and leave 2 tasks running
+ for (id <- Set(0, 1)) {
+ manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+ assert(sched.endedTasks(id) === Success)
+ }
+
+ // checkSpeculatableTasks checks that the task runtime is greater than the threshold for
+ // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for
+ // > 0ms, so advance the clock by 1ms here.
+ clock.advance(1)
+ assert(manager.checkSpeculatableTasks(0))
+ assert(sched.speculativeTasks.toSet === Set(2, 3))
+
+ // Offer a resource to start the speculative attempt for the running task 2.0
+ val taskOption = manager.resourceOffer("exec2", "host2", ANY)
+ assert(taskOption.isDefined)
+ val task4 = taskOption.get
+ assert(task4.index === 2)
+ assert(task4.taskId === 4)
+ assert(task4.executorId === "exec2")
+ assert(task4.attemptNumber === 1)
+ // Complete the speculative attempt for the running task
+ manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2)))
+ // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called
+ assert(killTaskCalled)
+
+ assert(resubmittedTasks.isEmpty)
+ // Host 2 is lost, meaning we lost the map output written by task 4
+ manager.executorLost("exec2", "host2", SlaveLost())
+ // Make sure that task with index 2 is re-submitted
+ assert(resubmittedTasks.contains(2))
+ }
+
private def createTaskResult(
id: Int,
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
new file mode 100644
index 000000000000..e57cb701b628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.Closeable
+import java.net._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class SocketAuthHelperSuite extends SparkFunSuite {
+
+ private val conf = new SparkConf()
+ private val authHelper = new SocketAuthHelper(conf)
+
+ test("successful auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ authHelper.authToServer(client)
+ server.close()
+ server.join()
+ assert(server.error == null)
+ assert(server.authenticated)
+ }
+ }
+ }
+
+ test("failed auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
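+ // This helper is built from a separate conf and therefore derives its own secret, so the
+ // handshake with the server (which uses authHelper's secret) should fail.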
+ val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ intercept[IllegalArgumentException] {
+ badHelper.authToServer(client)
+ }
+ server.close()
+ server.join()
+ assert(server.error != null)
+ assert(!server.authenticated)
+ }
+ }
+ }
+
+ private class ServerThread extends Thread with Closeable {
+
+ private val ss = new ServerSocket()
+ ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+
+ @volatile var error: Exception = _
+ @volatile var authenticated = false
+
+ setDaemon(true)
+ start()
+
+ def createClient(): Socket = {
+ new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ }
+
+ override def run(): Unit = {
+ var clientConn: Socket = null
+ try {
+ clientConn = ss.accept()
+ authHelper.authClient(clientConn)
+ authenticated = true
+ } catch {
+ case e: Exception =>
+ error = e
+ } finally {
+ Option(clientConn).foreach(_.close())
+ }
+ }
+
+ override def close(): Unit = {
+ try {
+ ss.close()
+ } finally {
+ interrupt()
+ }
+ }
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index fc78655bf52e..1234a8082c09 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.serializer
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream}
+import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -477,6 +478,17 @@ class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSpar
deserializationStream.close()
assert(serInstance.deserialize[Any](helloHello) === ((hello, hello)))
}
+
+ test("SPARK-25786: ByteBuffer.array -- UnsupportedOperationException") {
+ val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
+ val obj = "UnsupportedOperationException"
+ val serObj = serInstance.serialize(obj)
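+ // A direct ByteBuffer has no accessible backing array (ByteBuffer.array() throws
+ // UnsupportedOperationException), so deserialize must handle both heap and direct buffers.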
+ val byteBuffer = ByteBuffer.allocateDirect(serObj.array().length)
+ byteBuffer.put(serObj.array())
+ byteBuffer.flip()
+ assert(serInstance.deserialize[Any](serObj) === (obj))
+ assert(serInstance.deserialize[Any](byteBuffer) === (obj))
+ }
}
class ClassLoaderTestingObject
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala
new file mode 100644
index 000000000000..b9f0e873375b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.lang.{Long => JLong}
+
+import org.mockito.Mockito.when
+import org.scalatest.mockito.MockitoSugar
+
+import org.apache.spark._
+import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.memory._
+import org.apache.spark.unsafe.Platform
+
+class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar {
+
+ test("nested spill should be no-op") {
+ val conf = new SparkConf()
+ .setMaster("local[1]")
+ .setAppName("ShuffleExternalSorterSuite")
+ .set("spark.testing", "true")
+ .set("spark.testing.memory", "1600")
+ .set("spark.memory.fraction", "1")
+ sc = new SparkContext(conf)
+
+ val memoryManager = UnifiedMemoryManager(conf, 1)
+
+ var shouldAllocate = false
+
+ // Mock `TaskMemoryManager` to allocate free memory when `shouldAllocate` is true.
+ // This will trigger a nested spill and expose issues if we don't handle this case properly.
+ val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) {
+ override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = {
+ // ExecutionMemoryPool.acquireMemory will wait until there are 400 bytes for a task to use.
+ // So we leave 400 bytes for the task.
+ if (shouldAllocate &&
+ memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) {
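+ // MemoryManager.acquireExecutionMemory is private[memory], so invoke it reflectively to grab
+ // (almost) all of the remaining execution memory on behalf of another task attempt.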
+ val acquireExecutionMemoryMethod =
+ memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head
+ acquireExecutionMemoryMethod.invoke(
+ memoryManager,
+ JLong.valueOf(
+ memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400),
+ JLong.valueOf(1L), // taskAttemptId
+ MemoryMode.ON_HEAP
+ ).asInstanceOf[java.lang.Long]
+ }
+ super.acquireExecutionMemory(required, consumer)
+ }
+ }
+ val taskContext = mock[TaskContext]
+ val taskMetrics = new TaskMetrics
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ val sorter = new ShuffleExternalSorter(
+ taskMemoryManager,
+ sc.env.blockManager,
+ taskContext,
+ 100, // initialSize - This will require ShuffleInMemorySorter to acquire at least 800 bytes
+ 1, // numPartitions
+ conf,
+ new ShuffleWriteMetrics)
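+ // ShuffleExternalSorter keeps its in-memory sorter in a private field, so read it via
+ // reflection to observe when it runs out of space.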
+ val inMemSorter = {
+ val field = sorter.getClass.getDeclaredField("inMemSorter")
+ field.setAccessible(true)
+ field.get(sorter).asInstanceOf[ShuffleInMemorySorter]
+ }
+ // Allocate memory so that the next "insertRecord" call triggers a spill.
+ val bytes = new Array[Byte](1)
+ while (inMemSorter.hasSpaceForAnotherRecord) {
+ sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0)
+ }
+
+ // This flag will make the mocked TaskMemoryManager acquire free memory released by spill to
+ // trigger a nested spill.
+ shouldAllocate = true
+
+ // Should throw `SparkOutOfMemoryError` as there is not enough memory: `ShuffleInMemorySorter`
+ // will try to acquire 800 bytes but there are only 400 bytes available.
+ //
+ // Before the fix, a nested spill may use a released page, which causes two tasks to access the
+ // same memory page. When a task reads memory written by another task, many types of failures
+ // may happen. Here are some examples we have seen:
+ //
+ // - JVM crash. (This is easy to reproduce in the unit test as we fill newly allocated and
+ // deallocated memory with 0xa5 and 0x5a bytes which usually points to an invalid memory
+ // address)
+ // - java.lang.IllegalArgumentException: Comparison method violates its general contract!
+ // - java.lang.NullPointerException
+ // at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384)
+ // - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912
+ // because the size after growing exceeds size limitation 2147483632
+ intercept[SparkOutOfMemoryError] {
+ sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
index 997c7de8dd02..61ed0c804e04 100644
--- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
@@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
val s1Tasks = createTasks(4, execIds)
s1Tasks.foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId,
+ stages.head.attemptNumber,
+ task))
}
assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size)
@@ -211,55 +213,53 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
s1Tasks.foreach { task =>
check[TaskDataWrapper](task.taskId) { wrapper =>
- assert(wrapper.info.taskId === task.taskId)
+ assert(wrapper.taskId === task.taskId)
assert(wrapper.stageId === stages.head.stageId)
assert(wrapper.stageAttemptId === stages.head.attemptId)
- assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId)))
-
- val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger,
- -1L: JLong)
- assert(Arrays.equals(wrapper.runtime, runtime))
-
- assert(wrapper.info.index === task.index)
- assert(wrapper.info.attempt === task.attemptNumber)
- assert(wrapper.info.launchTime === new Date(task.launchTime))
- assert(wrapper.info.executorId === task.executorId)
- assert(wrapper.info.host === task.host)
- assert(wrapper.info.status === task.status)
- assert(wrapper.info.taskLocality === task.taskLocality.toString())
- assert(wrapper.info.speculative === task.speculative)
+ assert(wrapper.index === task.index)
+ assert(wrapper.attempt === task.attemptNumber)
+ assert(wrapper.launchTime === task.launchTime)
+ assert(wrapper.executorId === task.executorId)
+ assert(wrapper.host === task.host)
+ assert(wrapper.status === task.status)
+ assert(wrapper.taskLocality === task.taskLocality.toString())
+ assert(wrapper.speculative === task.speculative)
}
}
- // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code.
- s1Tasks.foreach { task =>
- val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED),
- Some(1L), None, true, false, None)
- listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(
- task.executorId,
- Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum)))))
- }
+ // Send two executor metrics updates. Only update one metric to avoid a lot of boilerplate code.
+ // The tasks are distributed among the two executors, so the executor-level metrics should
+ // hold half of the cumulative value of the metric being updated.
+ Seq(1L, 2L).foreach { value =>
+ s1Tasks.foreach { task =>
+ val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED),
+ Some(value), None, true, false, None)
+ listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(
+ task.executorId,
+ Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum)))))
+ }
- check[StageDataWrapper](key(stages.head)) { stage =>
- assert(stage.info.memoryBytesSpilled === s1Tasks.size)
- }
+ check[StageDataWrapper](key(stages.head)) { stage =>
+ assert(stage.info.memoryBytesSpilled === s1Tasks.size * value)
+ }
- val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage")
- .first(key(stages.head)).last(key(stages.head)).asScala.toSeq
- assert(execs.size > 0)
- execs.foreach { exec =>
- assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2)
+ val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage")
+ .first(key(stages.head)).last(key(stages.head)).asScala.toSeq
+ assert(execs.size > 0)
+ execs.foreach { exec =>
+ assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2)
+ }
}
// Fail one of the tasks, re-start it.
time += 1
s1Tasks.head.markFinished(TaskState.FAILED, time)
- listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber,
"taskType", TaskResultLost, s1Tasks.head, null))
time += 1
val reattempt = newAttempt(s1Tasks.head, nextTaskId())
- listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber,
reattempt))
assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1)
@@ -275,13 +275,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
check[TaskDataWrapper](s1Tasks.head.taskId) { task =>
- assert(task.info.status === s1Tasks.head.status)
- assert(task.info.errorMessage == Some(TaskResultLost.toErrorString))
+ assert(task.status === s1Tasks.head.status)
+ assert(task.errorMessage == Some(TaskResultLost.toErrorString))
}
check[TaskDataWrapper](reattempt.taskId) { task =>
- assert(task.info.index === s1Tasks.head.index)
- assert(task.info.attempt === reattempt.attemptNumber)
+ assert(task.index === s1Tasks.head.index)
+ assert(task.attempt === reattempt.attemptNumber)
}
// Kill one task, restart it.
@@ -289,7 +289,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
val killed = s1Tasks.drop(1).head
killed.finishTime = time
killed.failed = true
- listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber,
"taskType", TaskKilled("killed"), killed, null))
check[JobDataWrapper](1) { job =>
@@ -303,21 +303,21 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
check[TaskDataWrapper](killed.taskId) { task =>
- assert(task.info.index === killed.index)
- assert(task.info.errorMessage === Some("killed"))
+ assert(task.index === killed.index)
+ assert(task.errorMessage === Some("killed"))
}
// Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill.
time += 1
val denied = newAttempt(killed, nextTaskId())
val denyReason = TaskCommitDenied(1, 1, 1)
- listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber,
denied))
time += 1
denied.finishTime = time
denied.failed = true
- listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber,
"taskType", denyReason, denied, null))
check[JobDataWrapper](1) { job =>
@@ -331,13 +331,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
check[TaskDataWrapper](denied.taskId) { task =>
- assert(task.info.index === killed.index)
- assert(task.info.errorMessage === Some(denyReason.toErrorString))
+ assert(task.index === killed.index)
+ assert(task.errorMessage === Some(denyReason.toErrorString))
}
// Start a new attempt.
val reattempt2 = newAttempt(denied, nextTaskId())
- listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber,
reattempt2))
// Succeed all tasks in stage 1.
@@ -350,7 +350,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
time += 1
pending.foreach { task =>
task.markFinished(TaskState.FINISHED, time)
- listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber,
"taskType", Success, task, s1Metrics))
}
@@ -370,10 +370,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
pending.foreach { task =>
check[TaskDataWrapper](task.taskId) { wrapper =>
- assert(wrapper.info.errorMessage === None)
- assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L)
- assert(wrapper.info.taskMetrics.get.executorRunTime === 4L)
- assert(wrapper.info.duration === Some(task.duration))
+ assert(wrapper.errorMessage === None)
+ assert(wrapper.executorCpuTime === 2L)
+ assert(wrapper.executorRunTime === 4L)
+ assert(wrapper.duration === task.duration)
}
}
@@ -414,13 +414,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
time += 1
val s2Tasks = createTasks(4, execIds)
s2Tasks.foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId,
+ stages.last.attemptNumber,
+ task))
}
time += 1
s2Tasks.foreach { task =>
task.markFinished(TaskState.FAILED, time)
- listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber,
"taskType", TaskResultLost, task, null))
}
@@ -455,7 +457,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
// - Re-submit stage 2, all tasks, and succeed them and the stage.
val oldS2 = stages.last
- val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks,
+ val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks,
oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics)
time += 1
@@ -466,14 +468,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
val newS2Tasks = createTasks(4, execIds)
newS2Tasks.foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task))
}
time += 1
newS2Tasks.foreach { task =>
task.markFinished(TaskState.FINISHED, time)
- listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success,
- task, null))
+ listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType",
+ Success, task, null))
}
time += 1
@@ -522,14 +524,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
val j2s2Tasks = createTasks(4, execIds)
j2s2Tasks.foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId,
+ listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId,
+ j2Stages.last.attemptNumber,
task))
}
time += 1
j2s2Tasks.foreach { task =>
task.markFinished(TaskState.FINISHED, time)
- listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId,
+ listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber,
"taskType", Success, task, null))
}
@@ -814,12 +817,41 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize )
}
+ // Add block1 of rdd1 back to bm 1.
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(bm1, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize)))
+
+ check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+ assert(exec.info.rddBlocks === 3L)
+ assert(exec.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize + rdd2b1.memSize)
+ assert(exec.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize + rdd2b1.diskSize)
+ }
+
// Unpersist RDD1.
listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId))
intercept[NoSuchElementException] {
check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () }
}
+ // executor1 now only contains block1 from rdd2.
+ check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+ assert(exec.info.rddBlocks === 1L)
+ assert(exec.info.memoryUsed === rdd2b1.memSize)
+ assert(exec.info.diskUsed === rdd2b1.diskSize)
+ }
+
+ // Unpersist RDD2.
+ listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd2b1.rddId))
+ intercept[NoSuchElementException] {
+ check[RDDStorageInfoWrapper](rdd2b1.rddId) { _ => () }
+ }
+
+ check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+ assert(exec.info.rddBlocks === 0L)
+ assert(exec.info.memoryUsed === 0)
+ assert(exec.info.diskUsed === 0)
+ }
+
// Update a StreamBlock.
val stream1 = StreamBlockId(1, 1L)
listener.onBlockUpdated(SparkListenerBlockUpdated(
@@ -843,6 +875,24 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
intercept[NoSuchElementException] {
check[StreamBlockData](stream1.name) { _ => () }
}
+
+ // Update a BroadcastBlock.
+ val broadcast1 = BroadcastBlockId(1L)
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(bm1, broadcast1, level, 1L, 1L)))
+
+ check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+ assert(exec.info.memoryUsed === 1L)
+ assert(exec.info.diskUsed === 1L)
+ }
+
+ // Drop a BroadcastBlock.
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(bm1, broadcast1, StorageLevel.NONE, 1L, 1L)))
+ check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+ assert(exec.info.memoryUsed === 0)
+ assert(exec.info.diskUsed === 0)
+ }
}
test("eviction of old data") {
@@ -888,6 +938,27 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
assert(store.count(classOf[StageDataWrapper]) === 3)
assert(store.count(classOf[RDDOperationGraphWrapper]) === 3)
+ val dropped = stages.drop(1).head
+
+ // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be
+ // calculated, we need at least one finished task. The code in AppStatusStore uses
+ // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the
+ // task end event.
+ time += 1
+ val task = createTasks(1, Array("1")).head
+ listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task))
+
+ time += 1
+ task.markFinished(TaskState.FINISHED, time)
+ val metrics = TaskMetrics.empty
+ metrics.setExecutorRunTime(42L)
+ listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId,
+ "taskType", Success, task, metrics))
+
+ new AppStatusStore(store)
+ .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d))
+ assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3)
+
stages.drop(1).foreach { s =>
time += 1
s.completionTime = Some(time)
@@ -899,6 +970,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
intercept[NoSuchElementException] {
store.read(classOf[StageDataWrapper], Array(2, 0))
}
+ assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0)
val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3")
time += 1
@@ -919,13 +991,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
time += 1
val tasks = createTasks(2, Array("1"))
tasks.foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task))
}
assert(store.count(classOf[TaskDataWrapper]) === 2)
// Start a 3rd task. The finished tasks should be deleted.
createTasks(1, Array("1")).foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task))
}
assert(store.count(classOf[TaskDataWrapper]) === 2)
intercept[NoSuchElementException] {
@@ -934,7 +1006,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
// Start a 4th task. The first task should be deleted, even if it's still running.
createTasks(1, Array("1")).foreach { task =>
- listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task))
+ listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task))
}
assert(store.count(classOf[TaskDataWrapper]) === 2)
intercept[NoSuchElementException] {
@@ -942,6 +1014,220 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
}
+ test("eviction should respect job completion time") {
+ val testConf = conf.clone().set(MAX_RETAINED_JOBS, 2)
+ val listener = new AppStatusListener(store, testConf, true)
+
+ // Start job 1 and job 2
+ time += 1
+ listener.onJobStart(SparkListenerJobStart(1, time, Nil, null))
+ time += 1
+ listener.onJobStart(SparkListenerJobStart(2, time, Nil, null))
+
+ // Stop job 2 before job 1
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded))
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded))
+
+ // Start job 3 and job 2 should be evicted.
+ time += 1
+ listener.onJobStart(SparkListenerJobStart(3, time, Nil, null))
+ assert(store.count(classOf[JobDataWrapper]) === 2)
+ intercept[NoSuchElementException] {
+ store.read(classOf[JobDataWrapper], 2)
+ }
+ }
+
+ test("eviction should respect stage completion time") {
+ val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2)
+ val listener = new AppStatusListener(store, testConf, true)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+ val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")
+
+ // Start stage 1 and stage 2
+ time += 1
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+ time += 1
+ stage2.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties()))
+
+ // Stop stage 2 before stage 1
+ time += 1
+ stage2.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage2))
+ time += 1
+ stage1.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Start stage 3 and stage 2 should be evicted.
+ stage3.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties()))
+ assert(store.count(classOf[StageDataWrapper]) === 2)
+ intercept[NoSuchElementException] {
+ store.read(classOf[StageDataWrapper], Array(2, 0))
+ }
+ }
+
+ test("skipped stages should be evicted before completed stages") {
+ val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2)
+ val listener = new AppStatusListener(store, testConf, true)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+
+ // Start job 1
+ time += 1
+ listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null))
+
+ // Start and stop stage 1
+ time += 1
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+
+ time += 1
+ stage1.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Stop job 1 and stage 2 will become SKIPPED
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded))
+
+ // Submit stage 3 and verify stage 2 is evicted
+ val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")
+ time += 1
+ stage3.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties()))
+
+ assert(store.count(classOf[StageDataWrapper]) === 2)
+ intercept[NoSuchElementException] {
+ store.read(classOf[StageDataWrapper], Array(2, 0))
+ }
+ }
+
+ test("eviction should respect task completion time") {
+ val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2)
+ val listener = new AppStatusListener(store, testConf, true)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+
+ // Start task 1 and task 2
+ val tasks = createTasks(3, Array("1"))
+ tasks.take(2).foreach { task =>
+ listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task))
+ }
+
+ // Stop task 2 before task 1
+ time += 1
+ tasks(1).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null))
+ time += 1
+ tasks(0).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null))
+
+ // Start task 3 and task 2 should be evicted.
+ listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2)))
+ assert(store.count(classOf[TaskDataWrapper]) === 2)
+ intercept[NoSuchElementException] {
+ store.read(classOf[TaskDataWrapper], tasks(1).id)
+ }
+ }
+
+ test("lastStageAttempt should fail when the stage doesn't exist") {
+ val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1)
+ val listener = new AppStatusListener(store, testConf, true)
+ val appStore = new AppStatusStore(store)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+ val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")
+
+ time += 1
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+ stage1.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Make stage 3 complete before stage 2 so that stage 3 will be evicted
+ time += 1
+ stage3.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties()))
+ stage3.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage3))
+
+ time += 1
+ stage2.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties()))
+ stage2.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage2))
+
+ assert(appStore.asOption(appStore.lastStageAttempt(1)) === None)
+ assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2))
+ assert(appStore.asOption(appStore.lastStageAttempt(3)) === None)
+ }
+
+ test("SPARK-24415: update metrics for tasks that finish late") {
+ val listener = new AppStatusListener(store, conf, true)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+
+ // Start job
+ listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null))
+
+ // Start 2 stages
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties()))
+
+ // Start 2 Tasks
+ val tasks = createTasks(2, Array("1"))
+ tasks.foreach { task =>
+ listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task))
+ }
+
+ // Task 1 Finished
+ time += 1
+ tasks(0).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null))
+
+ // Stage 1 Completed
+ stage1.failureReason = Some("Failed")
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Stop job 1
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded))
+
+ // Task 2 Killed
+ time += 1
+ tasks(1).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType",
+ TaskKilled(reason = "Killed"), tasks(1), null))
+
+ // Ensure killed task metrics are updated
+ val allStages = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info)
+ val failedStages = allStages.filter(_.status == v1.StageStatus.FAILED)
+ assert(failedStages.size == 1)
+ assert(failedStages.head.numKilledTasks == 1)
+ assert(failedStages.head.numCompleteTasks == 1)
+
+ val allJobs = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info)
+ assert(allJobs.size == 1)
+ assert(allJobs.head.numKilledTasks == 1)
+ assert(allJobs.head.numCompletedTasks == 1)
+ assert(allJobs.head.numActiveStages == 1)
+ assert(allJobs.head.numFailedStages == 1)
+ }
+
test("driver logs") {
val listener = new AppStatusListener(store, conf, true)
@@ -960,7 +1246,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
}
- private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId)
+ private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber)
private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = {
val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T]
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala
new file mode 100644
index 000000000000..92f90f3d96dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.status
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.status.api.v1.TaskMetricDistributions
+import org.apache.spark.util.Distribution
+import org.apache.spark.util.kvstore._
+
+class AppStatusStoreSuite extends SparkFunSuite {
+
+ private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0)
+ private val stageId = 1
+ private val attemptId = 1
+
+ test("quantile calculation: 1 task") {
+ compareQuantiles(1, uiQuantiles)
+ }
+
+ test("quantile calculation: few tasks") {
+ compareQuantiles(4, uiQuantiles)
+ }
+
+ test("quantile calculation: more tasks") {
+ compareQuantiles(100, uiQuantiles)
+ }
+
+ test("quantile calculation: lots of tasks") {
+ compareQuantiles(4096, uiQuantiles)
+ }
+
+ test("quantile calculation: custom quantiles") {
+ compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99))
+ }
+
+ test("quantile cache") {
+ val store = new InMemoryStore()
+ (0 until 4096).foreach { i => store.write(newTaskData(i)) }
+
+ val appStore = new AppStatusStore(store)
+
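+ // 0.13 is not one of the quantiles the store caches, so no CachedQuantile entry
+ // should be written for it.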
+ appStore.taskSummary(stageId, attemptId, Array(0.13d))
+ intercept[NoSuchElementException] {
+ store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13"))
+ }
+
+ appStore.taskSummary(stageId, attemptId, Array(0.25d))
+ val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25"))
+
+ // Add a new task to force the cached quantile to be evicted, and make sure it's updated.
+ store.write(newTaskData(4096))
+ appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d))
+
+ val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25"))
+ assert(d1.taskCount != d2.taskCount)
+
+ store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50"))
+ intercept[NoSuchElementException] {
+ store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73"))
+ }
+
+ assert(store.count(classOf[CachedQuantile]) === 2)
+ }
+
+ private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = {
+ val store = new InMemoryStore()
+ val values = (0 until count).map { i =>
+ val task = newTaskData(i)
+ store.write(task)
+ i.toDouble
+ }.toArray
+
+ val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get
+ val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted)
+
+ dist.zip(summary.executorRunTime).foreach { case (expected, actual) =>
+ assert(expected === actual)
+ }
+ }
+
+ private def newTaskData(i: Int): TaskDataWrapper = {
+ new TaskDataWrapper(
+ i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None,
+ i, i, i, i, i, i, i, i, i, i,
+ i, i, i, i, i, i, i, i, i, i,
+ i, i, i, i, stageId, attemptId)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala
new file mode 100644
index 000000000000..9e74e86ad54b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status
+
+import java.util.Date
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.status.api.v1.{TaskData, TaskMetrics}
+
+class AppStatusUtilsSuite extends SparkFunSuite {
+
+ test("schedulerDelay") {
+ val runningTask = new TaskData(
+ taskId = 0,
+ index = 0,
+ attempt = 0,
+ launchTime = new Date(1L),
+ resultFetchStart = None,
+ duration = Some(100L),
+ executorId = "1",
+ host = "localhost",
+ status = "RUNNING",
+ taskLocality = "PROCESS_LOCAL",
+ speculative = false,
+ accumulatorUpdates = Nil,
+ errorMessage = None,
+ taskMetrics = Some(new TaskMetrics(
+ executorDeserializeTime = 0L,
+ executorDeserializeCpuTime = 0L,
+ executorRunTime = 0L,
+ executorCpuTime = 0L,
+ resultSize = 0L,
+ jvmGcTime = 0L,
+ resultSerializationTime = 0L,
+ memoryBytesSpilled = 0L,
+ diskBytesSpilled = 0L,
+ peakExecutionMemory = 0L,
+ inputMetrics = null,
+ outputMetrics = null,
+ shuffleReadMetrics = null,
+ shuffleWriteMetrics = null)))
+ assert(AppStatusUtils.schedulerDelay(runningTask) === 0L)
+
+ val finishedTask = new TaskData(
+ taskId = 0,
+ index = 0,
+ attempt = 0,
+ launchTime = new Date(1L),
+ resultFetchStart = None,
+ duration = Some(100L),
+ executorId = "1",
+ host = "localhost",
+ status = "SUCCESS",
+ taskLocality = "PROCESS_LOCAL",
+ speculative = false,
+ accumulatorUpdates = Nil,
+ errorMessage = None,
+ taskMetrics = Some(new TaskMetrics(
+ executorDeserializeTime = 5L,
+ executorDeserializeCpuTime = 3L,
+ executorRunTime = 90L,
+ executorCpuTime = 10L,
+ resultSize = 100L,
+ jvmGcTime = 10L,
+ resultSerializationTime = 2L,
+ memoryBytesSpilled = 0L,
+ diskBytesSpilled = 0L,
+ peakExecutionMemory = 100L,
+ inputMetrics = null,
+ outputMetrics = null,
+ shuffleReadMetrics = null,
+ shuffleWriteMetrics = null)))
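+ // Expected delay = duration - (deserialize + run + result serialization) = 100 - (5 + 90 + 2) = 3.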
+ assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 917db766f7f1..9c0699bc981f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
TaskContext.setTaskContext(
- new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+ new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 629eed49b04c..4d2168f0b338 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
@@ -1403,7 +1403,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
var numCalls = 0
- var tempFileManager: TempFileManager = null
+ var tempFileManager: DownloadFileManager = null
override def init(blockDataManager: BlockDataManager): Unit = {}
@@ -1413,7 +1413,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
- tempFileManager: TempFileManager): Unit = {
+ tempFileManager: DownloadFileManager): Unit = {
listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
}
@@ -1440,7 +1440,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
port: Int,
execId: String,
blockId: String,
- tempFileManager: TempFileManager): ManagedBuffer = {
+ tempFileManager: DownloadFileManager): ManagedBuffer = {
numCalls += 1
this.tempFileManager = tempFileManager
if (numCalls <= maxFailures) {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 5bfe9905ff17..24244f9657fb 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
@@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}
+ test("big blocks are not checked for corruption") {
+ val corruptStream = mock(classOf[InputStream])
+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+ val corruptBuffer = mock(classOf[ManagedBuffer])
+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+ doReturn(10000L).when(corruptBuffer).size()
+
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+ doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+ val localBlockLengths = Seq[Tuple2[BlockId, Long]](
+ ShuffleBlockId(0, 0, 0) -> corruptBuffer.size()
+ )
+
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlockLengths = Seq[Tuple2[BlockId, Long]](
+ ShuffleBlockId(0, 1, 0) -> corruptBuffer.size()
+ )
+
+ val transfer = createMockTransfer(
+ Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer))
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (localBmId, localBlockLengths),
+ (remoteBmId, remoteBlockLengths)
+ )
+
+ val taskContext = TaskContext.empty()
+ val iterator = new ShuffleBlockFetcherIterator(
+ taskContext,
+ transfer,
+ blockManager,
+ blocksByAddress,
+ (_, in) => new LimitedInputStream(in, 10000),
+ 2048,
+ Int.MaxValue,
+ Int.MaxValue,
+ Int.MaxValue,
+ true)
+ // Blocks should be returned without exceptions.
+ assert(Set(iterator.next()._1, iterator.next()._1) ===
+ Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
+ }
+
test("retry corrupt blocks (disabled)") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -437,12 +482,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val remoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
val transfer = mock(classOf[BlockTransferService])
- var tempFileManager: TempFileManager = null
+ var tempFileManager: DownloadFileManager = null
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager]
+ tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
Future {
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 661d0d48d2f3..6044563f7dde 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -28,22 +28,82 @@ import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.status.AppStatusStore
-import org.apache.spark.ui.jobs.{StagePage, StagesTab}
+import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus}
+import org.apache.spark.status.config._
+import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable}
class StagePageSuite extends SparkFunSuite with LocalSparkContext {
private val peakExecutionMemory = 10
+ test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") {
+ val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
+ val statusStore = AppStatusStore.createLiveStore(conf)
+ try {
+ val stageData = new StageData(
+ status = StageStatus.ACTIVE,
+ stageId = 1,
+ attemptId = 1,
+ numTasks = 1,
+ numActiveTasks = 1,
+ numCompleteTasks = 1,
+ numFailedTasks = 1,
+ numKilledTasks = 1,
+ numCompletedIndices = 1,
+
+ executorRunTime = 1L,
+ executorCpuTime = 1L,
+ submissionTime = None,
+ firstTaskLaunchedTime = None,
+ completionTime = None,
+ failureReason = None,
+
+ inputBytes = 1L,
+ inputRecords = 1L,
+ outputBytes = 1L,
+ outputRecords = 1L,
+ shuffleReadBytes = 1L,
+ shuffleReadRecords = 1L,
+ shuffleWriteBytes = 1L,
+ shuffleWriteRecords = 1L,
+ memoryBytesSpilled = 1L,
+ diskBytesSpilled = 1L,
+
+ name = "stage1",
+ description = Some("description"),
+ details = "detail",
+ schedulingPool = "pool1",
+
+ rddIds = Seq(1),
+ accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")),
+ tasks = None,
+ executorSummary = None,
+ killedTasksSummary = Map.empty
+ )
+ val taskTable = new TaskPagedTable(
+ stageData,
+ basePath = "/a/b/c",
+ currentTime = 0,
+ pageSize = 10,
+ sortColumn = "Index",
+ desc = false,
+ store = statusStore
+ )
+ val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet
+ assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet)
+ } finally {
+ statusStore.close()
+ }
+ }
+
test("peak execution memory should displayed") {
- val conf = new SparkConf(false)
- val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT)
+ val html = renderStagePage().toString().toLowerCase(Locale.ROOT)
val targetString = "peak execution memory"
assert(html.contains(targetString))
}
test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
- val conf = new SparkConf(false)
- val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT)
+ val html = renderStagePage().toString().toLowerCase(Locale.ROOT)
// verify min/25/50/75/max show task value not cumulative values
assert(html.contains(s"$peakExecutionMemory.0 b | " * 5))
}
@@ -52,7 +112,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
* Render a stage page started with the given conf and return the HTML.
* This also runs a dummy stage to populate the page with useful content.
*/
- private def renderStagePage(conf: SparkConf): Seq[Node] = {
+ private def renderStagePage(): Seq[Node] = {
+ val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
val statusStore = AppStatusStore.createLiveStore(conf)
val listener = statusStore.listener.get
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index 326546787ab6..0f20eea73504 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -706,6 +706,23 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
}
+ test("stages page should show skipped stages") {
+ withSpark(newSparkContext()) { sc =>
+ val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache()
+ rdd.count()
+ rdd.count()
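+ // The second count reuses the shuffle output of the first, so its map stage is skipped.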
+
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ goToUi(sc, "/stages")
+ find(id("skipped")).get.text should be("Skipped Stages (1)")
+ }
+ val stagesJson = getJson(sc.ui.get, "stages")
+ stagesJson.children.size should be (4)
+ val stagesStatus = stagesJson.children.map(_ \ "status")
+ stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1)
+ }
+ }
+
def goToUi(sc: SparkContext, path: String): Unit = {
goToUi(sc.ui.get, path)
}
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
index a04644d57ed8..fe0a9a471a65 100644
--- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import org.apache.spark._
+import org.apache.spark.serializer.JavaSerializer
class AccumulatorV2Suite extends SparkFunSuite {
@@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
assert(acc3.isZero)
assert(acc3.value === "")
}
+
+ test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
+ class MyData(val i: Int) extends Serializable
+ val param = new AccumulatorParam[MyData] {
+ override def zero(initialValue: MyData): MyData = new MyData(0)
+ override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
+ }
+
+ val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
+ acc.metadata = AccumulatorMetadata(
+ AccumulatorContext.newId(),
+ Some("test"),
+ countFailedValues = false)
+ AccumulatorContext.register(acc)
+
+ val ser = new JavaSerializer(new SparkConf).newInstance()
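+ // Serializing the wrapper should not require MyData to define equals/hashCode.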
+ ser.serialize(acc)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index eaea6b030c15..cde250ca6566 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1167,6 +1167,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
}
}
+
+ object MalformedClassObject {
+ class MalformedClass
+ }
+
+ test("Safe getSimpleName") {
+ // getSimpleName on class of MalformedClass will result in error: Malformed class name
+ // Utils.getSimpleName works
+ val err = intercept[java.lang.InternalError] {
+ classOf[MalformedClassObject.MalformedClass].getSimpleName
+ }
+ assert(err.getMessage === "Malformed class name")
+
+ assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
+ "UtilsSuite$MalformedClassObject$MalformedClass")
+ }
}
private class SimpleExtension
diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/multi-channel/BGRA_alpha_60.png
new file mode 100644
index 000000000000..913637cd2828
Binary files /dev/null and b/data/mllib/images/multi-channel/BGRA_alpha_60.png differ
diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml
index bbda824dd13b..945686de4996 100644
--- a/dev/checkstyle-suppressions.xml
+++ b/dev/checkstyle-suppressions.xml
@@ -17,7 +17,7 @@
+"https://checkstyle.org/dtds/suppressions_1_1.dtd">
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
-Once the edges have be partitioned the key challenge to efficient graph-parallel computation is
+Once the edges have been partitioned the key challenge to efficient graph-parallel computation is
efficiently joining vertex attributes with the edges. Because real-world graphs typically have more
edges than vertices, we move vertex attributes to the edges. Because not all partitions will
contain edges adjacent to all vertices we internally maintain a routing table which identifies where
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index bf979f3c73a5..ddd2f4b49ca0 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `DataFrame` in
-`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
+`LogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.
@@ -97,10 +97,9 @@ only available on the driver.
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
-This will likely change when multiclass classification is supported.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
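A minimal sketch of the `binarySummary` accessor described above (not taken from this patch; the app name, data path and `maxIter` value are placeholders):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("BinarySummarySketch").getOrCreate()
// Any binary-labeled training DataFrame works; this path is the bundled sample data.
val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val lrModel = new LogisticRegression().setMaxIter(10).fit(training)

// The generic summary is available for both binary and multiclass models...
println(lrModel.summary.accuracy)

// ...while binary-only metrics are reached through binarySummary instead of a cast.
val binarySummary = lrModel.binarySummary
println(binarySummary.areaUnderROC)
binarySummary.roc.show()
```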
Continuing the earlier example:
@@ -111,10 +110,9 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
-Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
Continuing the earlier example:
@@ -125,7 +123,8 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
provides a summary for a
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).
Continuing the earlier example:
@@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
**Examples**
The following example shows how to train a multiclass logistic regression
-model with elastic net regularization.
+model with elastic net regularization, as well as extract the multiclass
+training summary for evaluating the model.
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 72643137d96b..3370eb389327 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co
numeric or categorical features. Behavior and handling of column data types is as follows:
- Numeric columns: For numeric features, the hash value of the column name is used to map the
-feature value to its index in the feature vector. Numeric features are never treated as
-categorical, even when they are integers. You must explicitly convert numeric columns containing
-categorical features to strings first.
+feature value to its index in the feature vector. By default, numeric features are not treated
+as categorical (even when they are integers). To treat them as categorical, specify the relevant
+columns using the `categoricalCols` parameter.
- String columns: For categorical features, the hash value of the string "column_name=value"
is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features
are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with
@@ -775,35 +775,43 @@ for more details on the API.
-## OneHotEncoder
+## OneHotEncoder (Deprecated since 2.3.0)
-[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
+Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces a `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030).
+
+`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead.
+
+## OneHotEncoderEstimator
+
+[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
+
+`OneHotEncoderEstimator` can transform multiple columns, returning a one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler).
+
+`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error).
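A short sketch of the estimator/model workflow just described (column names and toy rows are invented for illustration):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("OneHotEncoderEstimatorSketch").getOrCreate()
import spark.implicits._

// Two columns of label indices, e.g. produced earlier by StringIndexer.
val df = Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0)).toDF("categoryIndex1", "categoryIndex2")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex1", "categoryIndex2"))
  .setOutputCols(Array("categoryVec1", "categoryVec2"))
  .setHandleInvalid("keep") // unseen categories at transform time get an extra index

// Fitting records the number of categories per column, which is what lets the model
// be applied to new data (the limitation of the old stateless OneHotEncoder).
val model = encoder.fit(df)
model.transform(df).show(false)
```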
**Examples**
-Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder)
-for more details on the API.
+Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API.
-{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
+{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
-Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html)
+Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html)
for more details on the API.
-{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
+{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
-Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder)
-for more details on the API.
+Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API.
-{% include_example python/ml/onehot_encoder_example.py %}
+{% include_example python/ml/onehot_encoder_estimator_example.py %}
@@ -1283,6 +1291,57 @@ for more details on the API.
+## VectorSizeHint
+
+It can sometimes be useful to explicitly specify the size of the vectors for a column of
+`VectorType`. For example, `VectorAssembler` uses size information from its input columns to
+produce size information and metadata for its output column. While in some cases this information
+can be obtained by inspecting the contents of the column, in a streaming dataframe the contents are
+not available until the stream is started. `VectorSizeHint` allows a user to explicitly specify the
+vector size for a column so that `VectorAssembler`, or other transformers that might
+need to know vector size, can use that column as an input.
+
+To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this
+transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying
+the vector size. Downstream operations on the resulting dataframe can get this size using the
+metadata.
+
+`VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its
+behavior when the vector column contains nulls or vectors of the wrong size. By default
+`handleInvalid` is set to "error", indicating an exception should be thrown. This parameter can
+also be set to "skip", indicating that rows containing invalid values should be filtered out from
+the resulting dataframe, or "optimistic", indicating that the column should not be checked for
+invalid values and all rows should be kept. Note that the use of "optimistic" can cause the
+resulting dataframe to be in an inconsistent state, meaning the metadata for the column
+`VectorSizeHint` was applied to does not match the contents of that column. Users should take care
+to avoid this kind of inconsistent state.
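A sketch of how `VectorSizeHint` might be used in front of `VectorAssembler` (column names and values are illustrative only):

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("VectorSizeHintSketch").getOrCreate()
import spark.implicits._

val df = Seq(
  (Vectors.dense(1.0, 2.0, 3.0), 4.0),
  (Vectors.dense(5.0, 6.0, 7.0), 8.0)
).toDF("userFeatures", "clicks")

// Declare the size of "userFeatures" in the column metadata so that downstream stages
// (e.g. VectorAssembler on a streaming dataframe) do not have to inspect the data.
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)
  .setHandleInvalid("error") // "skip" drops bad rows, "optimistic" disables the check

val assembler = new VectorAssembler()
  .setInputCols(Array("userFeatures", "clicks"))
  .setOutputCol("features")

assembler.transform(sizeHint.transform(df)).show(false)
```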
+
+
+
+
+Refer to the [VectorSizeHint Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSizeHint)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala %}
+
+
+
+
+Refer to the [VectorSizeHint Java docs](api/java/org/apache/spark/ml/feature/VectorSizeHint.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java %}
+
+
+
+
+Refer to the [VectorSizeHint Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorSizeHint)
+for more details on the API.
+
+{% include_example python/ml/vector_size_hint_example.py %}
+
+
+
## QuantileDiscretizer
`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index f6288e7c32d9..aea07be34cb8 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -72,32 +72,31 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4
[^1]: To learn more about the benefits and background of system optimised natives, you may wish to
watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/).
-# Highlights in 2.2
+# Highlights in 2.3
-The list below highlights some of the new features and enhancements added to MLlib in the `2.2`
+The list below highlights some of the new features and enhancements added to MLlib in the `2.3`
release of Spark:
-* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all
- users or items, matching the functionality in `mllib`
- ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)).
- Performance was also improved for both `ml` and `mllib`
- ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and
- [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587))
-* [`Correlation`](ml-statistics.html#correlation) and
- [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames`
- ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and
- [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635))
-* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining
- ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503))
-* `GLM` now supports the full `Tweedie` family
- ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929))
-* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset
- ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568))
-* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine)
- for linear Support Vector Machine classification
- ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709))
-* Logistic regression now supports constraints on the coefficients during training
- ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047))
+* Built-in support for reading images into a `DataFrame` was added
+([SPARK-21866](https://issues.apache.org/jira/browse/SPARK-21866)).
+* [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) was added, and should be
+used instead of the existing `OneHotEncoder` transformer. The new estimator supports
+transforming multiple columns.
+* Multiple column support was also added to `QuantileDiscretizer` and `Bucketizer`
+([SPARK-22397](https://issues.apache.org/jira/browse/SPARK-22397) and
+[SPARK-20542](https://issues.apache.org/jira/browse/SPARK-20542))
+* A new [`FeatureHasher`](ml-features.html#featurehasher) transformer was added
+ ([SPARK-13969](https://issues.apache.org/jira/browse/SPARK-13969)).
+* Added support for evaluating multiple models in parallel when performing cross-validation using
+[`TrainValidationSplit` or `CrossValidator`](ml-tuning.html)
+([SPARK-19357](https://issues.apache.org/jira/browse/SPARK-19357)).
+* Improved support for custom pipeline components in Python (see
+[SPARK-21633](https://issues.apache.org/jira/browse/SPARK-21633) and
+[SPARK-21542](https://issues.apache.org/jira/browse/SPARK-21542)).
+* `DataFrame` functions for descriptive summary statistics over vector columns
+([SPARK-19634](https://issues.apache.org/jira/browse/SPARK-19634)).
+* Robust linear regression with Huber loss
+([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)).
# Migration guide
@@ -109,42 +108,40 @@ and the migration guide below will explain all changes between releases.
### Breaking changes
-There are no breaking changes.
+* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner
+and better accommodate the addition of the multi-class summary. This is a breaking change for user
+code that casts a `LogisticRegressionTrainingSummary` to a
+`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
+method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail
+(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which
+will still work correctly for both multinomial and binary cases.
### Deprecations and changes of behavior
**Deprecations**
-There are no deprecations.
+* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the
+new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator)
+(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that
+`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but
+`OneHotEncoderEstimator` will be kept as an alias).
**Changes of behavior**
* [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027):
- We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial), in 2.2 and earlier version,
- the `OneVsRest` parallelism would be parallelism of the default threadpool in scala.
-
-## From 2.1 to 2.2
-
-### Breaking changes
-
-There are no breaking changes.
-
-### Deprecations and changes of behavior
-
-**Deprecations**
-
-There are no deprecations.
-
-**Changes of behavior**
-
-* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787):
- Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`).
- **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class.
-* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772):
- Fixed inconsistency between Python and Scala APIs for `Param.copy` method.
-* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569):
- `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception
- would always be thrown regardless of the setting of the `handleInvalid` parameter.
+ The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and
+ earlier versions, the level of parallelism was set to the default threadpool size in Scala.
+* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156):
+ The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than
+ `1`. This will cause training results to be different between `2.3` and earlier versions.
+* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681):
+ Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients
+ when some features had zero variance.
+* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957):
+ Tree algorithms now use mid-points for split values. This may change results from model training.
+* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657):
+ Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent
+ with the output in R. This may change results from model training in this scenario.
## Previous Spark versions
diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md
index 687d7c893036..f4b0df58cf63 100644
--- a/docs/ml-migration-guides.md
+++ b/docs/ml-migration-guides.md
@@ -7,6 +7,29 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT
The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide).
+## From 2.1 to 2.2
+
+### Breaking changes
+
+There are no breaking changes.
+
+### Deprecations and changes of behavior
+
+**Deprecations**
+
+There are no deprecations.
+
+**Changes of behavior**
+
+* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787):
+ Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`).
+ **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class.
+* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772):
+ Fixed inconsistency between Python and Scala APIs for `Param.copy` method.
+* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569):
+ `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception
+ would always be thrown regardless of the setting of the `handleInvalid` parameter.
+
## From 2.0 to 2.1
### Breaking changes
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md
index aa92c0a37c0f..e22e9003c30f 100644
--- a/docs/ml-pipeline.md
+++ b/docs/ml-pipeline.md
@@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s.
For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`.
This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`.
-## Saving and Loading Pipelines
+## ML persistence: Saving and Loading Pipelines
-Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported.
+It is often worthwhile to save a model or a pipeline to disk for later use. In Spark 1.6, model import/export functionality was added to the Pipeline API.
+As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage.
+
+ML persistence works across Scala, Java and Python. However, R currently uses a modified format,
+so models saved in R can only be loaded back in R; this should be fixed in the future and is
+tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572).
+
+### Backwards compatibility for ML persistence
+
+In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML
+model or Pipeline in one version of Spark, then you should be able to load it back and use it in a
+future version of Spark. However, there are rare exceptions, described below.
+
+Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark
+version X loadable by Spark version Y?
+
+* Major versions: No guarantees, but best-effort.
+* Minor and patch versions: Yes; these are backwards compatible.
+* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible.
+
+Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y?
+
+* Major versions: No guarantees, but best-effort.
+* Minor and patch versions: Identical behavior, except for bug fixes.
+
+For both model persistence and model behavior, any breaking changes across a minor version or patch
+version are reported in the Spark version release notes. If a breakage is not reported in release
+notes, then it should be treated as a bug to be fixed.
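A sketch of the save/load round trip these guarantees cover (the toy data, paths and stage parameters are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("PersistenceSketch").getOrCreate()
import spark.implicits._

// Tiny toy training set with "text" and "label" columns.
val training = Seq((0L, "a b c spark", 1.0), (1L, "x y z", 0.0)).toDF("id", "text", "label")

val pipeline = new Pipeline().setStages(Array(
  new Tokenizer().setInputCol("text").setOutputCol("words"),
  new HashingTF().setInputCol("words").setOutputCol("features"),
  new LogisticRegression().setMaxIter(10)))

val model = pipeline.fit(training)

// Both the unfitted pipeline and the fitted model can be persisted...
pipeline.write.overwrite().save("/tmp/unfit-lr-pipeline")
model.write.overwrite().save("/tmp/lr-pipeline-model")

// ...and loaded back, possibly by a later minor or patch release of Spark.
val sameModel = PipelineModel.load("/tmp/lr-pipeline-model")
```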
# Code examples
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index 7f277543d2e9..8afea2c576e6 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -413,13 +413,13 @@ A ranking system usually deals with a set of $M$ users
$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
-Each user ($u_i$) having a set of $N$ ground truth relevant documents
+Each user ($u_i$) having a set of $N_i$ ground truth relevant documents
-$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+$$D_i = \left\{d_0, d_1, ..., d_{N_i-1}\right\}$$
-And a list of $Q$ recommended documents, in order of decreasing relevance
+And a list of $Q_i$ recommended documents, in order of decreasing relevance
-$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+$$R_i = \left[r_0, r_1, ..., r_{Q_i-1}\right]$$
The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
sets and the effectiveness of the algorithms can be measured using the metrics listed below.
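For concreteness, a small sketch of computing these metrics with `RankingMetrics` (the document ids are made up; each pair is one user's recommendations $R_i$ and ground truth $D_i$):

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("RankingMetricsSketch").getOrCreate()

val predictionAndLabels = spark.sparkContext.parallelize(Seq(
  (Array(1, 2, 3, 4, 5), Array(1, 2, 5)), // user 0: R_0 and D_0
  (Array(6, 7, 8), Array(7, 9))))         // user 1: R_1 and D_1

val metrics = new RankingMetrics(predictionAndLabels)
println(metrics.precisionAt(3))       // precision at k, averaged over users
println(metrics.meanAveragePrecision) // MAP
println(metrics.ndcgAt(3))            // NDCG at k
```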
@@ -439,10 +439,10 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{
Precision at k
|
- $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$
+ $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(Q_i, k) - 1} rel_{D_i}(R_i(j))}$
|
- Precision at k is a measure of
+ Precision at k is a measure of
how many of the first k recommended documents are in the set of true relevant documents averaged across all
users. In this metric, the order of the recommendations is not taken into account.
|
@@ -450,10 +450,10 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{