
Commit 7dd973d

huanliwang-db authored and anishshri-db committed
[SPARK-54118][SS] Improve the put/merge operation in ListState when there are multiple values
In SS TWS, the `put(array)` operation in ListState puts the first element and then merges the remaining elements one by one. So putting an array with 100 elements costs 1 put + 99 merges, which can perform worse than a single put of the entire array. `merge(array)` has the same issue.

Ran the benchmark with inputRate = 1M keys/second and 1M key cardinality; here are the batch duration results with TWS in SS.

Before:

```
Batch Duration (ms)
p50   666.00
p90   899.70
p95   969.35
p99  1081.94
```

After:

```
Batch Duration (ms)
p50  488
p90  576
p95  609
p99  713
```

### What changes were proposed in this pull request?

Improve the existing `put(array)` and `merge(array)` implementation to reduce the number of RocksDB operations.

### Why are the changes needed?

Performance improvement.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UT and new UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#52820 from huanliwang-db/huanliwang-db/improve-list-state.

Authored-by: huanliwang-db <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
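To make the improvement concrete, here is a minimal sketch of the two write paths. The `SimpleStore` trait and method signatures are illustrative stand-ins, not Spark's actual API:

```scala
object ListWriteSketch {
  // Illustrative stand-in for a state store that supports single-value and
  // multi-value writes; not Spark's StateStore interface.
  trait SimpleStore {
    def put(key: Array[Byte], value: Array[Byte]): Unit
    def merge(key: Array[Byte], value: Array[Byte]): Unit
    def putList(key: Array[Byte], values: Seq[Array[Byte]]): Unit
  }

  // Old path: 1 put + (N - 1) merges for an N-element list.
  def writeAllOld(store: SimpleStore, key: Array[Byte], values: Seq[Array[Byte]]): Unit =
    values.zipWithIndex.foreach { case (v, i) =>
      if (i == 0) store.put(key, v) else store.merge(key, v)
    }

  // New path: one putList call covers the whole list with a single store write.
  def writeAllNew(store: SimpleStore, key: Array[Byte], values: Seq[Array[Byte]]): Unit =
    store.putList(key, values)
}
```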
1 parent ac717dd commit 7dd973d

8 files changed, +255 -12 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/ListStateImpl.scala

Lines changed: 6 additions & 12 deletions
@@ -88,21 +88,15 @@ class ListStateImpl[S](
     validateNewState(newState)
 
     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-    var isFirst = true
     var entryCount = 0L
     TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
 
-    newState.foreach { v =>
-      val encodedValue = stateTypesEncoder.encodeValue(v)
-      if (isFirst) {
-        store.put(encodedKey, encodedValue, stateName)
-        isFirst = false
-      } else {
-        store.merge(encodedKey, encodedValue, stateName)
-      }
+    val encodedValues = newState.map { v =>
       entryCount += 1
       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
+      stateTypesEncoder.encodeValue(v).copy()
     }
+    store.putList(encodedKey, encodedValues, stateName)
     updateEntryCount(encodedKey, entryCount)
   }

@@ -123,12 +117,12 @@ class ListStateImpl[S](
 
     val encodedKey = stateTypesEncoder.encodeGroupingKey()
     var entryCount = getEntryCount(encodedKey)
-    newState.foreach { v =>
-      val encodedValue = stateTypesEncoder.encodeValue(v)
-      store.merge(encodedKey, encodedValue, stateName)
+    val encodedValues = newState.map { v =>
       entryCount += 1
       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
+      stateTypesEncoder.encodeValue(v).copy()
     }
+    store.mergeList(encodedKey, encodedValues, stateName)
     updateEntryCount(encodedKey, entryCount)
   }
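A note on the `.copy()` calls added above: the values are now collected into an array before a single `putList`/`mergeList` call, rather than being written to the store one at a time. A minimal sketch of why copying matters, assuming the encoder reuses a single UnsafeRow buffer across `encodeValue` calls (typical for projection-based encoders); `collectEncoded` is an illustrative helper, not part of the actual change:

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

object CopyBeforeBatching {
  // Collect encoded values for a later batched write.
  def collectEncoded[S](values: Array[S], encodeValue: S => UnsafeRow): Array[UnsafeRow] =
    // Without .copy(), every slot could point at the same reused buffer and
    // hold only the last encoded value by the time putList runs.
    values.map(v => encodeValue(v).copy())
}
```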

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 9 additions & 0 deletions
@@ -292,11 +292,20 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", providerName)
     }
 
+    override def putList(key: UnsafeRow, values: Array[UnsafeRow], colFamilyName: String): Unit = {
+      throw StateStoreErrors.unsupportedOperationException("putList", providerName)
+    }
+
     override def merge(key: UnsafeRow,
         value: UnsafeRow,
         colFamilyName: String): Unit = {
       throw StateStoreErrors.unsupportedOperationException("merge", providerName)
     }
+
+    override def mergeList(
+        key: UnsafeRow, values: Array[UnsafeRow], colFamilyName: String): Unit = {
+      throw StateStoreErrors.unsupportedOperationException("mergeList", providerName)
+    }
   }
 
   def getMetricsForProvider(): Map[String, Long] = synchronized {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 97 additions & 0 deletions
@@ -40,6 +40,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.internal.{LogEntry, Logging, LogKeys}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.{NextIterator, Utils}
 
 // RocksDB operations that could acquire/release the instance lock

@@ -1050,6 +1051,69 @@
     changelogWriter.foreach(_.put(keyWithPrefix, valueWithChecksum))
   }
 
+  /**
+   * Convert the given list of value row bytes into a single byte array. The returned array
+   * bytes supports additional values to be later merged to it.
+   */
+  private def getListValuesInArrayByte(values: List[Array[Byte]]): Array[Byte] = {
+    // Delimit each value row bytes with a single byte delimiter, the last
+    // value row won't have a delimiter at the end.
+    val delimiterNum = values.length - 1
+    // The bytes in values already include the bytes length prefix
+    val totalSize = values.map(_.length).sum +
+      delimiterNum // for each delimiter
+
+    val result = new Array[Byte](totalSize)
+    var pos = Platform.BYTE_ARRAY_OFFSET
+
+    values.zipWithIndex.foreach { case (rowBytes, idx) =>
+      // Write the data
+      Platform.copyMemory(rowBytes, Platform.BYTE_ARRAY_OFFSET, result, pos, rowBytes.length)
+      pos += rowBytes.length
+
+      // Add the delimiter - we are using "," as the delimiter
+      if (idx < delimiterNum) {
+        result(pos - Platform.BYTE_ARRAY_OFFSET) = 44.toByte
+      }
+      // Move the position for delimiter
+      pos += 1
+    }
+    result
+  }
+
+  /**
+   * Put the given list of values for the given key.
+   * @note
+   *   This update is not committed to disk until commit() is called.
+   */
+  def putList(
+      key: Array[Byte],
+      values: List[Array[Byte]],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
+      includesPrefix: Boolean = false,
+      deriveCfName: Boolean = false): Unit = {
+    updateMemoryUsageIfNeeded()
+    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    val valuesInArrayByte = getListValuesInArrayByte(values)
+
+    val columnFamilyName = if (deriveCfName && useColumnFamilies) {
+      val (_, cfName) = decodeStateRowWithPrefix(keyWithPrefix)
+      cfName
+    } else {
+      cfName
+    }
+
+    handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
+    db.put(writeOptions, keyWithPrefix, valuesInArrayByte)
+    changelogWriter.foreach(_.put(keyWithPrefix, valuesInArrayByte))
+  }
+
   /**
    * Merge the given value for the given key. This is equivalent to the Atomic
    * Read-Modify-Write operation in RocksDB, known as the "Merge" operation. The

@@ -1094,6 +1158,39 @@
     changelogWriter.foreach(_.merge(keyWithPrefix, valueWithChecksum))
   }
 
+  /**
+   * Merge the given list of values for the given key.
+   *
+   * This is similar to the merge() function, but allows merging multiple values at once. The
+   * provided values will be appended to the current list of values for the given key.
+   */
+  def mergeList(
+      key: Array[Byte],
+      values: List[Array[Byte]],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
+      includesPrefix: Boolean = false,
+      deriveCfName: Boolean = false): Unit = {
+    updateMemoryUsageIfNeeded()
+    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    val columnFamilyName = if (deriveCfName && useColumnFamilies) {
+      val (_, cfName) = decodeStateRowWithPrefix(keyWithPrefix)
+      cfName
+    } else {
+      cfName
+    }
+
+    val valueInArrayByte = getListValuesInArrayByte(values)
+
+    handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
+    db.merge(writeOptions, keyWithPrefix, valueInArrayByte)
+    changelogWriter.foreach(_.merge(keyWithPrefix, valueInArrayByte))
+  }
+
   /**
    * Remove the key if present.
    * @note This update is not committed to disk until commit() is called.
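The key idea in `getListValuesInArrayByte` is to pre-join the value rows with a single `','` (0x2C) byte, the same delimiter an append-style merge produces for this key (see the `"1,2,3,4"` assertions in the suite below), so that a later `merge`/`mergeList` can keep appending to the blob written by `putList`. A simplified re-expression of that layout without the `Platform` API, as a sketch only (not the committed code):

```scala
import java.io.ByteArrayOutputStream

object ListValueLayoutSketch {
  // Join value byte arrays with a single ',' delimiter, no trailing delimiter.
  def joinWithCommaDelimiter(values: List[Array[Byte]]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    values.zipWithIndex.foreach { case (bytes, idx) =>
      out.write(bytes)
      if (idx < values.length - 1) out.write(','.toInt) // 0x2C delimiter
    }
    out.toByteArray
  }

  // e.g. joinWithCommaDelimiter(List("7", "8", "9").map(_.getBytes))
  //      produces the same bytes as "7,8,9".getBytes
}
```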

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 47 additions & 0 deletions
@@ -300,6 +300,31 @@ private[sql] class RocksDBStateStoreProvider
     rocksDB.merge(keyEncoder.encodeKey(key), valueEncoder.encodeValue(value), colFamilyName)
   }
 
+  override def mergeList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    validateAndTransitionState(UPDATE)
+    verify(state == UPDATING, "Cannot merge after already committed or aborted")
+    verifyColFamilyOperations("merge", colFamilyName)
+
+    val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+    val keyEncoder = kvEncoder._1
+    val valueEncoder = kvEncoder._2
+    verify(
+      valueEncoder.supportsMultipleValuesPerKey,
+      "Merge operation requires an encoder" +
+        " which supports multiple values for a single key")
+    verify(key != null, "Key cannot be null")
+    require(values != null, "Cannot merge a null value")
+    values.foreach(v => require(v != null, "Cannot merge a null value in the array"))
+
+    rocksDB.mergeList(
+      keyEncoder.encodeKey(key),
+      values.map(valueEncoder.encodeValue).toList,
+      colFamilyName)
+  }
+
   override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = {
     validateAndTransitionState(UPDATE)
     verify(state == UPDATING, "Cannot put after already committed or aborted")

@@ -311,6 +336,28 @@ private[sql] class RocksDBStateStoreProvider
     rocksDB.put(kvEncoder._1.encodeKey(key), kvEncoder._2.encodeValue(value), colFamilyName)
   }
 
+  override def putList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String): Unit = {
+    validateAndTransitionState(UPDATE)
+    verify(state == UPDATING, "Cannot put after already committed or aborted")
+    verify(key != null, "Key cannot be null")
+    require(values != null, "Cannot put a null value")
+    values.foreach(v => require(v != null, "Cannot put a null value in the array"))
+    verifyColFamilyOperations("put", colFamilyName)
+
+    val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+    verify(
+      kvEncoder._2.supportsMultipleValuesPerKey,
+      "Multi-value put operation requires an encoder" +
+        " which supports multiple values for a single key")
+    rocksDB.putList(
+      kvEncoder._1.encodeKey(key),
+      values.map(kvEncoder._2.encodeValue).toList,
+      colFamilyName)
+  }
+
   override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
     validateAndTransitionState(UPDATE)
     verify(state == UPDATING, "Cannot remove after already committed or aborted")

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 22 additions & 0 deletions
@@ -208,6 +208,16 @@ trait StateStore extends ReadStateStore {
       value: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
 
+  /**
+   * Put a new list of non-null values for a non-null key. Implementations must be aware that the
+   * UnsafeRows in the params can be reused, and must make copies of the data as needed for
+   * persistence.
+   */
+  def putList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
+
   /**
    * Remove a single non-null key.
    */

@@ -225,6 +235,18 @@ trait StateStore extends ReadStateStore {
   def merge(key: UnsafeRow, value: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
 
+  /**
+   * Merges the provided list of values with existing values of a non-null key. If a existing
+   * value does not exist, this operation behaves as [[StateStore.putArray()]].
+   *
+   * It is expected to throw exception if Spark calls this method without setting
+   * multipleValuesPerKey as true for the column family.
+   */
+  def mergeList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
+
   /**
    * Commit all the updates that have been made to the store, and return the new version.
    * Implementations should ensure that no more updates (puts, removes) can be after a commit in
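A toy model of the `mergeList` contract documented above (append to the existing values for the key; behave like a fresh put when the key is absent). This is not a `StateStore` implementation and skips UnsafeRow encoding entirely; `ToyListStore` is illustrative only:

```scala
import scala.collection.mutable

class ToyListStore {
  private val data = mutable.Map.empty[String, Vector[String]]

  // putList replaces whatever was stored for the key.
  def putList(key: String, values: Seq[String]): Unit =
    data(key) = values.toVector

  // mergeList appends; when the key is absent it degenerates to putList.
  def mergeList(key: String, values: Seq[String]): Unit =
    data(key) = data.getOrElse(key, Vector.empty) ++ values

  def get(key: String): Seq[String] = data.getOrElse(key, Vector.empty)
}
```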

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala

Lines changed: 8 additions & 0 deletions
@@ -51,6 +51,10 @@ class MemoryStateStore extends StateStore() {
   override def put(key: UnsafeRow, newValue: UnsafeRow, colFamilyName: String): Unit =
     map.put(key.copy(), newValue.copy())
 
+  override def putList(key: UnsafeRow, newValues: Array[UnsafeRow], colFamilyName: String): Unit = {
+    throw new UnsupportedOperationException("Doesn't support put multiple values put")
+  }
+
   override def remove(key: UnsafeRow, colFamilyName: String): Unit = map.remove(key)
 
   override def commit(): Long = version + 1

@@ -78,6 +82,10 @@ class MemoryStateStore extends StateStore() {
     throw new UnsupportedOperationException("Doesn't support multiple values per key")
   }
 
+  override def mergeList(key: UnsafeRow, values: Array[UnsafeRow], colFamilyName: String): Unit = {
+    throw new UnsupportedOperationException("Doesn't support multiple values merge")
+  }
+
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
     throw new UnsupportedOperationException("Doesn't support multiple values per key")
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 14 additions & 0 deletions
@@ -123,6 +123,13 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
     innerStore.put(key, value, colFamilyName)
   }
 
+  override def putList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    innerStore.putList(key, values, colFamilyName)
+  }
+
   override def remove(
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {

@@ -136,6 +143,13 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
     innerStore.merge(key, value, colFamilyName)
   }
 
+  override def mergeList(
+      key: UnsafeRow,
+      values: Array[UnsafeRow],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    innerStore.mergeList(key, values, colFamilyName)
+  }
+
   override def commit(): Long = innerStore.commit()
   override def metrics: StateStoreMetrics = innerStore.metrics
   override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = {

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala

Lines changed: 52 additions & 0 deletions
@@ -2016,6 +2016,58 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
     }
   }
 
+  test("RocksDB: ensure putList / mergeList operation correctness") {
+    withTempDir { dir =>
+      val remoteDir = Utils.createTempDir().toString
+      // minDeltasForSnapshot being 5 ensures that only changelog files are created
+      // for the 3 commits below
+      val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false)
+      new File(remoteDir).delete() // to make sure that the directory gets created
+      withDB(remoteDir, conf = conf, useColumnFamilies = true) { db =>
+        db.load(0)
+        db.put("a", "1".getBytes)
+        db.mergeList("a", Seq("2", "3", "4").map(_.getBytes).toList)
+        db.commit()
+
+        db.load(1)
+        db.mergeList("a", Seq("5", "6").map(_.getBytes).toList)
+        db.commit()
+
+        db.load(2)
+        db.remove("a")
+        db.commit()
+
+        db.load(3)
+        db.putList("a", Seq("7", "8", "9").map(_.getBytes).toList)
+        db.commit()
+
+        db.load(4)
+        db.putList("a", Seq("10", "11").map(_.getBytes).toList)
+        db.commit()
+
+        db.load(1)
+        assert(new String(db.get("a")) === "1,2,3,4")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1,2,3,4")))
+
+        db.load(2)
+        assert(new String(db.get("a")) === "1,2,3,4,5,6")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1,2,3,4,5,6")))
+
+        db.load(3)
+        assert(db.get("a") === null)
+        assert(db.iterator().isEmpty)
+
+        db.load(4)
+        assert(new String(db.get("a")) === "7,8,9")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "7,8,9")))
+
+        db.load(5)
+        assert(new String(db.get("a")) === "10,11")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "10,11")))
+      }
+    }
+  }
+
   testWithStateStoreCheckpointIdsAndColumnFamilies("RocksDBFileManager: delete orphan files",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) {
     case (enableStateStoreCheckpointIds, colFamiliesEnabled) =>
