@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state

import java.io._
import java.util.Locale
import java.util.concurrent.atomic.LongAdder

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -164,7 +165,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}

override def metrics: StateStoreMetrics = {
- StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty)
+ // NOTE: we report the estimated size of the cache as "memoryUsedBytes", and the
+ // estimated size of the state for the current version as "stateOnCurrentVersionSizeBytes"
+ val metricsFromProvider: Map[String, Long] = getMetricsForProvider()
+
+ val customMetrics = metricsFromProvider.flatMap { case (name, value) =>
+   // a linear search is fine here because the list of supported metrics is small
+   supportedCustomMetrics.find(_.name == name).map(_ -> value)
+ } + (metricStateOnCurrentVersionSizeBytes -> SizeEstimator.estimate(mapToUpdate))
+
+ StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics)
}
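
The lookup above joins the values reported by getMetricsForProvider() to the declared
metric objects by name; unknown names are silently dropped. A self-contained sketch of
the same pattern, using a simplified descriptor type and hypothetical values:

// Simplified stand-in for Spark's StateStoreCustomMetric hierarchy.
case class CustomMetric(name: String, desc: String)

val supported = Seq(
  CustomMetric("loadedMapCacheHitCount", "cache hits"),
  CustomMetric("loadedMapCacheMissCount", "cache misses"))

val reported = Map("loadedMapCacheHitCount" -> 3L, "unknownMetric" -> 1L)

// A linear find is fine for a short list; names with no declared metric fall out.
val joined: Map[CustomMetric, Long] = reported.flatMap { case (name, value) =>
  supported.find(_.name == name).map(_ -> value)
}
// joined == Map(CustomMetric("loadedMapCacheHitCount", "cache hits") -> 3L)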

/**
@@ -179,6 +189,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}
}

def getMetricsForProvider(): Map[String, Long] = synchronized {
Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps),
metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(),
metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum())
}

/** Get the state store for making updates to create a new `version` of the store. */
override def getStore(version: Long): StateStore = synchronized {
require(version >= 0, "Version cannot be less than 0")
@@ -224,7 +240,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}

override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
- Nil
+ metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss ::
+   Nil
}

override def toString(): String = {
@@ -245,6 +262,21 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)

private val loadedMapCacheHitCount: LongAdder = new LongAdder
private val loadedMapCacheMissCount: LongAdder = new LongAdder

private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric =
StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes",
"estimated size of state only on current version")

private lazy val metricLoadedMapCacheHit: StateStoreCustomMetric =
StateStoreCustomSumMetric("loadedMapCacheHitCount",
"count of cache hit on states cache in provider")

private lazy val metricLoadedMapCacheMiss: StateStoreCustomMetric =
StateStoreCustomSumMetric("loadedMapCacheMissCount",
"count of cache miss on states cache in provider")

private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)

private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = {
@@ -276,13 +308,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
// Shortcut if the map for this version is already there to avoid a redundant put.
val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) }
if (loadedCurrentVersionMap.isDefined) {
loadedMapCacheHitCount.increment()
return loadedCurrentVersionMap.get
}

logWarning(s"The state for version $version doesn't exist in loadedMaps. " +
"Reading snapshot file and delta files if needed..." +
"Note that this is normal for the first batch of starting query.")

loadedMapCacheMissCount.increment()

val (result, elapsedMs) = Utils.timeTakenMs {
val snapshotCurrentVersionMap = readSnapshotFile(version)
if (snapshotCurrentVersionMap.isDefined) {
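
The cache accounting added in this file is easiest to see in isolation. Below is a
minimal, self-contained model of the pattern; all names are hypothetical stand-ins for
the provider's loadedMaps cache and its two counters. LongAdder fits here because the
hot paths only increment, and an exact total is needed only when metrics are reported:

import java.util.concurrent.atomic.LongAdder
import scala.collection.mutable

object LoadedMapCacheModel {
  private val loadedMaps = new mutable.HashMap[Long, Map[String, Int]]
  private val hitCount = new LongAdder
  private val missCount = new LongAdder

  def loadMap(version: Long): Map[String, Int] = synchronized {
    loadedMaps.get(version) match {
      case Some(map) =>
        hitCount.increment()  // served from the cache
        map
      case None =>
        missCount.increment() // the real provider reads snapshot/delta files here
        val map = Map("placeholder" -> version.toInt)
        loadedMaps.put(version, map)
        map
    }
  }

  // LongAdder.sum() is a cheap point-in-time read for reporting.
  def metrics: Map[String, Long] = Map(
    "loadedMapCacheHitCount" -> hitCount.sum(),
    "loadedMapCacheMissCount" -> missCount.sum())
}
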
@@ -138,6 +138,8 @@ trait StateStoreCustomMetric {
def name: String
def desc: String
}

case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric
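
The three concrete case classes exist so downstream code can choose a rendering per
metric kind; later in this diff the operator maps them to plain, size, and timing SQL
metrics. A standalone sketch of that kind of dispatch follows. The mirrored types are
sealed here only to make the match checkably exhaustive (Spark's trait is not sealed),
and the format helper is hypothetical:

sealed trait CustomMetric { def name: String; def desc: String }
case class SumMetric(name: String, desc: String) extends CustomMetric
case class SizeMetric(name: String, desc: String) extends CustomMetric
case class TimingMetric(name: String, desc: String) extends CustomMetric

def format(metric: CustomMetric, value: Long): String = metric match {
  case SumMetric(name, _)    => s"$name: $value"
  case SizeMetric(name, _)   => s"$name: $value bytes"
  case TimingMetric(name, _) => s"$name: $value ms"
}
// format(SizeMetric("stateOnCurrentVersionSizeBytes", ""), 2048L)
//   == "stateOnCurrentVersionSizeBytes: 2048 bytes"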

@@ -269,6 +269,8 @@ class SymmetricHashJoinStateManager(
keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
keyWithIndexToValueMetrics.customMetrics.map {
case (s @ StateStoreCustomSumMetric(_, desc), value) =>
s.copy(desc = newDesc(desc)) -> value
case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
s.copy(desc = newDesc(desc)) -> value
case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
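
The binder-plus-copy pattern above rewrites only each metric's description while
preserving its name and concrete type, so the join sides can present clearer labels
without redefining the metrics. A standalone sketch under mirrored, hypothetical names:

case class SumMetric(name: String, desc: String)

def newDesc(desc: String): String = s"$desc (shared by both sides of the join)"

val metrics: Map[SumMetric, Long] =
  Map(SumMetric("loadedMapCacheHitCount", "count of cache hit") -> 3L)

// `m @` binds the whole metric, so copy() changes the description only.
val relabeled = metrics.map { case (m @ SumMetric(_, desc), value) =>
  m.copy(desc = newDesc(desc)) -> value
}
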
@@ -90,10 +90,18 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
* the driver after this SparkPlan has been executed and metrics have been updated.
*/
def getProgress(): StateOperatorProgress = {
val customMetrics = stateStoreCustomMetrics
.map(entry => entry._1 -> longMetric(entry._1).value)

val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] =
new java.util.HashMap(customMetrics.mapValues(long2Long).asJava)

new StateOperatorProgress(
numRowsTotal = longMetric("numTotalStateRows").value,
numRowsUpdated = longMetric("numUpdatedStateRows").value,
- memoryUsedBytes = longMetric("stateMemory").value)
+ memoryUsedBytes = longMetric("stateMemory").value,
+ javaConvertedCustomMetrics
+ )
}

/** Records the duration of running `body` for the next query progress update. */
@@ -115,6 +123,8 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
provider.supportedCustomMetrics.map {
case StateStoreCustomSumMetric(name, desc) =>
name -> SQLMetrics.createMetric(sparkContext, desc)
case StateStoreCustomSizeMetric(name, desc) =>
name -> SQLMetrics.createSizeMetric(sparkContext, desc)
case StateStoreCustomTimingMetric(name, desc) =>
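
One subtlety in getProgress() above: customMetrics is exposed as
java.util.Map[String, java.lang.Long], while longMetric(name).value yields a Scala
Long, so each value is boxed via long2Long before asJava. A self-contained sketch of
just that conversion, with made-up values:

import scala.collection.JavaConverters._

val customMetrics: Map[String, Long] =
  Map("loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)

// mapValues(long2Long) boxes scala.Long into java.lang.Long; asJava yields a
// read-only view, which the HashMap constructor copies into a mutable map.
val javaMetrics: java.util.HashMap[String, java.lang.Long] =
  new java.util.HashMap(customMetrics.mapValues(long2Long).asJava)

assert(javaMetrics.get("loadedMapCacheHitCount") == 1L)
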
@@ -38,7 +38,8 @@ import org.apache.spark.annotation.InterfaceStability
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
val numRowsUpdated: Long,
- val memoryUsedBytes: Long
+ val memoryUsedBytes: Long,
+ val customMetrics: ju.Map[String, JLong] = new ju.HashMap()
) extends Serializable {

/** The compact JSON representation of this progress. */
@@ -48,12 +49,20 @@ class StateOperatorProgress private[sql](
def prettyJson: String = pretty(render(jsonValue))

private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress =
- new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes)
+ new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, customMetrics)

private[sql] def jsonValue: JValue = {
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes))
("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
("customMetrics" -> {
if (!customMetrics.isEmpty) {
val keys = customMetrics.keySet.asScala.toSeq.sorted
keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _)
} else {
JNothing
}
})
}

override def toString: String = prettyJson
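
Rendering with sorted keys keeps the JSON deterministic, and returning JNothing for an
empty map drops the customMetrics field entirely, so progress JSON from queries without
custom metrics is unchanged. A self-contained json4s sketch of the same pattern, with
made-up values:

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import scala.collection.JavaConverters._

val customMetrics = new java.util.HashMap[String, java.lang.Long]()
customMetrics.put("loadedMapCacheHitCount", 1L)
customMetrics.put("stateOnCurrentVersionSizeBytes", 2L)

val json: JValue =
  ("memoryUsedBytes" -> JInt(3)) ~
  ("customMetrics" -> {
    if (!customMetrics.isEmpty) {
      val keys = customMetrics.keySet.asScala.toSeq.sorted
      keys.map { k => k -> JInt(customMetrics.get(k).toLong): JObject }.reduce(_ ~ _)
    } else {
      JNothing // json4s drops JNothing fields at render time
    }
  })

println(pretty(render(json)))
// {
//   "memoryUsedBytes" : 3,
//   "customMetrics" : {
//     "loadedMapCacheHitCount" : 1,
//     "stateOnCurrentVersionSizeBytes" : 2
//   }
// }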
@@ -193,6 +193,22 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
}

test("reports memory usage on current version") {
def getSizeOfStateForCurrentVersion(metrics: StateStoreMetrics): Long = {
val metricPair = metrics.customMetrics.find(_._1.name == "stateOnCurrentVersionSizeBytes")
assert(metricPair.isDefined)
metricPair.get._2
}

val provider = newStoreProvider()
val store = provider.getStore(0)
val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics)

put(store, "a", 1)
store.commit()
assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
}
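
The assertion holds because SizeEstimator.estimate walks the entire object graph, so an
extra entry strictly increases the estimate. A small illustration; the absolute numbers
depend on the JVM and object layout:

import org.apache.spark.util.SizeEstimator

val map = new java.util.HashMap[String, Array[Byte]]()
val emptySize = SizeEstimator.estimate(map)
map.put("a", Array.fill[Byte](64)(0))
// The estimate now also covers the new entry, its key string, and the 64-byte array.
assert(SizeEstimator.estimate(map) > emptySize)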

test("StateStore.get") {
quietly {
val dir = newDir()
@@ -507,6 +523,90 @@
assert(CreateAtomicTestManager.cancelCalledInCreateAtomic)
}

test("expose metrics with custom metrics to StateStoreMetrics") {
def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = {
val metricPair = metrics.customMetrics.find(_._1.name == name)
assert(metricPair.isDefined)
metricPair.get._2
}

def getLoadedMapSizeMetric(metrics: StateStoreMetrics): Long = {
metrics.memoryUsedBytes
}

def assertCacheHitAndMiss(
metrics: StateStoreMetrics,
expectedCacheHitCount: Long,
expectedCacheMissCount: Long): Unit = {
val cacheHitCount = getCustomMetric(metrics, "loadedMapCacheHitCount")
val cacheMissCount = getCustomMetric(metrics, "loadedMapCacheMissCount")
assert(cacheHitCount === expectedCacheHitCount)
assert(cacheMissCount === expectedCacheMissCount)
}

val provider = newStoreProvider()

// Verify state before starting a new set of updates
assert(getLatestData(provider).isEmpty)

val store = provider.getStore(0)
assert(!store.hasCommitted)

assert(store.metrics.numKeys === 0)

val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics)
assert(initialLoadedMapSize >= 0)
assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0)

put(store, "a", 1)
assert(store.metrics.numKeys === 1)

put(store, "b", 2)
put(store, "aa", 3)
assert(store.metrics.numKeys === 3)
remove(store, _.startsWith("a"))
assert(store.metrics.numKeys === 1)
assert(store.commit() === 1)

assert(store.hasCommitted)

val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics)
assert(loadedMapSizeForVersion1 > initialLoadedMapSize)
assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0)

val storeV2 = provider.getStore(1)
assert(!storeV2.hasCommitted)
assert(storeV2.metrics.numKeys === 1)

put(storeV2, "cc", 4)
assert(storeV2.metrics.numKeys === 2)
assert(storeV2.commit() === 2)

assert(storeV2.hasCommitted)

val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics)
assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1)
assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0)

val reloadedProvider = newStoreProvider(store.id)
// getStore(1) loads the state of version 1 (to create version 2),
// so version 2 itself is not loaded into the provider's cache
val reloadedStore = reloadedProvider.getStore(1)
assert(reloadedStore.metrics.numKeys === 1)

assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1)
assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0,
expectedCacheMissCount = 1)

// now we are loading version 2
val reloadedStoreV2 = reloadedProvider.getStore(2)
assert(reloadedStoreV2.metrics.numKeys === 2)

assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1)
assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0,
expectedCacheMissCount = 2)
}

override def newStoreProvider(): HDFSBackedStateStoreProvider = {
newStoreProvider(opId = Random.nextInt(), partition = 0)
}
@@ -231,7 +231,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
test("event ordering") {
val listener = new EventCollector
withListenerAdded(listener) {
- for (i <- 1 to 100) {
+ for (i <- 1 to 50) {
[Review comment] @HeartSaVioR (Contributor, Author), Jun 7, 2018:
After the patch this test starts failing: it just means more time is needed to run
this loop 100 times; it doesn't mean the logic is broken. Decreasing the number works
for me.

[Review reply] Contributor:
Makes sense, and I agree with the implicit claim that this slowdown isn't too worrying.
listener.reset()
require(listener.startEvent === null)
testStream(MemoryStream[Int].toDS)(
@@ -58,7 +58,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "stateOperators" : [ {
| "numRowsTotal" : 0,
| "numRowsUpdated" : 1,
| "memoryUsedBytes" : 2
| "memoryUsedBytes" : 3,
| "customMetrics" : {
| "loadedMapCacheHitCount" : 1,
| "loadedMapCacheMissCount" : 0,
| "stateOnCurrentVersionSizeBytes" : 2
| }
| } ],
| "sources" : [ {
| "description" : "source",
@@ -230,7 +235,11 @@ object StreamingQueryStatusAndProgressSuite {
"avg" -> "2016-12-05T20:54:20.827Z",
"watermark" -> "2016-12-05T20:54:20.827Z").asJava),
stateOperators = Array(new StateOperatorProgress(
- numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)),
+ numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3,
+ customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L,
+   "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)
+   .mapValues(long2Long).asJava)
+ )),
sources = Array(
new SourceProgress(
description = "source",
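
End to end, the new field is what user code observes. Below is a hypothetical
monitoring helper (not part of this patch) that reads the customMetrics map from a
progress update, for example inside StreamingQueryListener.onQueryProgress or from
query.lastProgress:

import org.apache.spark.sql.streaming.StreamingQueryProgress

def logStateStoreCacheMetrics(progress: StreamingQueryProgress): Unit = {
  progress.stateOperators.foreach { op =>
    val hits = op.customMetrics.getOrDefault("loadedMapCacheHitCount", 0L)
    val misses = op.customMetrics.getOrDefault("loadedMapCacheMissCount", 0L)
    val stateBytes = op.customMetrics.getOrDefault("stateOnCurrentVersionSizeBytes", 0L)
    println(s"rows=${op.numRowsTotal} memory=${op.memoryUsedBytes}B " +
      s"stateOnCurrentVersion=${stateBytes}B cacheHits=$hits cacheMisses=$misses")
  }
}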