Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,39 @@ class TaskMetrics extends Serializable {
* Storage statuses of any blocks that have been updated as a result of this task.
*/
var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None

/**
* Custom task-specific metrics. Any piece of Spark machinery that cares to track custom task-related metrics can
* now use the setCustomMetric(name,value) method on TaskMetrics to do that.
* e.g. Any RDD that wants to track a custom metric of its interest (related to execution of a task) can
* call setCustomMetric(name,value) inside its compute() method.
* Map Key -> name of custom metric
* Map Value -> list of numeric values for custom metric
*/
private val _customMetrics: scala.collection.mutable.HashMap[String, List[Long]] = scala.collection.mutable.HashMap()

/**
* Custom task-specific metrics. Any piece of Spark machinery that cares to track custom task-related metrics can
* now use the setCustomMetric(name,value) method on TaskMetrics to do that.
* e.g. Any RDD that wants to track a custom metric of its interest (related to execution of a task) can
* call setCustomMetric(name,value) inside its compute() method.
* Map Key -> name of custom metric
* Map Value -> list of numeric values for custom metric
*/
def customMetrics = _customMetrics

/**
* Convenience method for setting a custom metric
* @param metricName name of custom metric
* @param metricValue value for custom metric
*/
def setCustomMetric(metricName: String, metricValue: Long) {
if (_customMetrics.contains(metricName)) {
_customMetrics(metricName) = _customMetrics(metricName) ++ List(metricValue)
} else {
_customMetrics(metricName) = List(metricValue)
}
}
}

private[spark] object TaskMetrics {
Expand Down
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
// scalastyle:on
val taskHeaders: Seq[String] =
Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
Seq("Duration", "GC Time", "Result Ser Time") ++
Seq("Duration", "GC Time", "Result Ser Time", "Custom Metrics") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
{if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
{if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++
Expand Down Expand Up @@ -220,6 +220,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
val customMetrics = metrics.map(_.customMetrics).getOrElse(scala.collection.mutable.HashMap())

val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead)
val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
Expand Down Expand Up @@ -261,6 +262,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
<td sorttable_customkey={serializationTime.toString}>
{if (serializationTime > 0) UIUtils.formatDuration(serializationTime) else ""}
</td>
<td>
{customMetrics.foldLeft("")( (previous, pair) => {
(if (previous.isEmpty) "" else previous + "\n") + pair._1 + " = " + pair._2.mkString(",")
})}
</td>
{if (shuffleRead) {
<td sorttable_customkey={shuffleReadSortable}>
{shuffleReadReadable}
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ private[spark] object JsonProtocol {
("Status" -> blockStatusToJson(status))
})
}.getOrElse(JNothing)
val customMetrics =
taskMetrics.customMetrics.map { case (name, value) =>
customMetricToJson(name, value)
}
("Host Name" -> taskMetrics.hostname) ~
("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~
("Executor Run Time" -> taskMetrics.executorRunTime) ~
Expand All @@ -225,7 +229,12 @@ private[spark] object JsonProtocol {
("Disk Bytes Spilled" -> taskMetrics.diskBytesSpilled) ~
("Shuffle Read Metrics" -> shuffleReadMetrics) ~
("Shuffle Write Metrics" -> shuffleWriteMetrics) ~
("Updated Blocks" -> updatedBlocks)
("Updated Blocks" -> updatedBlocks) ~
("Custom Metrics" -> customMetrics)
}

def customMetricToJson(metricName: String, metricValue: List[Long]): JValue = {
metricName -> metricValue
}

def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = {
Expand Down Expand Up @@ -527,6 +536,13 @@ private[spark] object JsonProtocol {
(id, status)
}
}

// set customMetrics
for (
obj <- (json \ "Custom Metrics").extract[List[JValue]];
(metricName, JArray(v)) <- obj.asInstanceOf[JObject].obj;
metricValue <- v.map(_.extract[Long])
) metrics.setCustomMetric(metricName, metricValue)
metrics
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ class JsonProtocolSuite extends FunSuite {
testEvent(applicationEnd, applicationEndJsonString)
}

test("Custom Metrics in TaskMetrics") {
val taskMetric = makeTaskMetrics(33333L, 44444L, 55555L, 66666L, 7, 8)
taskMetric.setCustomMetric("Custom TaskMetric 1", 1)
taskMetric.setCustomMetric("Custom TaskMetric 2", 2)
testTaskMetrics(taskMetric)
}

test("Dependent Classes") {
testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L))
testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L))
Expand Down Expand Up @@ -280,6 +287,7 @@ class JsonProtocolSuite extends FunSuite {
assertOptionEquals(
metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics, assertShuffleWriteEquals)
assertOptionEquals(metrics1.updatedBlocks, metrics2.updatedBlocks, assertBlocksEquals)
assert(metrics1.customMetrics === metrics2.customMetrics)
}

private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
Expand Down