
Commit 07e08c0

huanliwang-db authored and HeartSaVioR committed
[SPARK-48105][SS][3.5] Fix the race condition between state store unloading and snapshotting
### What changes were proposed in this pull request?

* When we close the HDFS state store, we should only remove the entry from `loadedMaps` rather than doing the active data cleanup; the JVM GC should be able to collect those objects once they become unreachable.
* We should wait for the maintenance thread to stop before unloading the providers.

### Why are the changes needed?

There are two race conditions between state store snapshotting and state store unloading which could result in query failure and potential data corruption.

Case 1:

1. The maintenance thread pool encounters some issue and calls [stopMaintenanceTask](https://github.com/apache/spark/blob/d9d79a54a3cd487380039c88ebe9fa708e0dcf23/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L774); this function further calls [threadPool.stop](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L587). However, it doesn't wait for the stop operation to complete and moves on to the state store [unload and clear](https://github.com/apache/spark/blob/d9d79a54a3cd487380039c88ebe9fa708e0dcf23/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L775-L778).
2. The provider unload will [close the state store](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L719-L721), which [clears the values of loadedMaps](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L353-L355) for the HDFS-backed state store.
3. The not-yet-stopped maintenance thread may still be running and trying to take the snapshot, but the data in the underlying `HDFSBackedStateStoreMap` has already been removed. If this snapshot process completes successfully, we will write corrupted data and the following batches will consume this corrupted data.

Case 2:

1. In executor_1, the maintenance thread is going to take the snapshot for state_store_1: it retrieves the `HDFSBackedStateStoreMap` object from `loadedMaps`, and after this the maintenance thread [releases the lock of the loadedMaps](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L750-L751).
2. state_store_1 is loaded on another executor, e.g. executor_2.
3. Another state store, state_store_2, is loaded on executor_1 and [reportActiveStoreInstance](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L854-L871) is sent to the driver.
4. executor_1 does the [unload](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L713) for those state stores which are no longer active, which clears the data entries in the `HDFSBackedStateStoreMap`.
5. The snapshotting thread is terminated and uploads the incomplete snapshot to the cloud because the [iterator doesn't have a next element](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L634) after the clear.
6. Future batches consume the corrupted data.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

```
[info] Run completed in 2 minutes, 55 seconds.
[info] Total number of tests run: 153
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 153, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 271 s (04:31), completed May 2, 2024, 6:26:33 PM
```

before this change

```
[info] - state store unload/close happens during the maintenance *** FAILED *** (648 milliseconds)
[info]   Vector("a1", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a2", "a20", "a3", "a4", "a5", "a6", "a7", "a8", "a9") did not equal ArrayBuffer("a8") (StateStoreSuite.scala:414)
[info]   Analysis:
[info]   Vector1(0: "a1" -> "a8", 1: "a10" -> , 2: "a11" -> , 3: "a12" -> , 4: "a13" -> , 5: "a14" -> , 6: "a15" -> , 7: "a16" -> , 8: "a17" -> , 9: "a18" -> , 10: "a19" -> , 11: "a2" -> , 12: "a20" -> , 13: "a3" -> , 14: "a4" -> , 15: "a5" -> , 16: "a6" -> , 17: "a7" -> , 18: "a8" -> , 19: "a9" -> )
[info]   org.scalatest.exceptions.TestFailedException:
[info]   at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472)
[info]   at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471)
[info]   at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1231)
[info]   at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:1295)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.$anonfun$new$39(StateStoreSuite.scala:414)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuiteBase.tryWithProviderResource(StateStoreSuite.scala:1663)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.$anonfun$new$38(StateStoreSuite.scala:394)
18:32:09.694 WARN org.apache.spark.sql.execution.streaming.state.StateStoreSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.streaming.state.StateStoreSuite, threads: ForkJoinPool.commonPool-worker-1 (daemon=true) =====
[info]   at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
[info]   at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
[info]   at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
[info]   at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
[info]   at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
[info]   at org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
[info]   at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
[info]   at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
[info]   at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
[info]   at org.scalatest.Transformer.apply(Transformer.scala:22)
[info]   at org.scalatest.Transformer.apply(Transformer.scala:20)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
[info]   at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
[info]   at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
[info]   at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
[info]   at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
[info]   at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.org$scalatest$BeforeAndAfter$$super$runTest(StateStoreSuite.scala:90)
[info]   at org.scalatest.BeforeAndAfter.runTest(BeforeAndAfter.scala:213)
[info]   at org.scalatest.BeforeAndAfter.runTest$(BeforeAndAfter.scala:203)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.runTest(StateStoreSuite.scala:90)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
[info]   at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
[info]   at scala.collection.immutable.List.foreach(List.scala:334)
[info]   at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
[info]   at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
[info]   at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
[info]   at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
[info]   at org.scalatest.Suite.run(Suite.scala:1114)
[info]   at org.scalatest.Suite.run$(Suite.scala:1096)
[info]   at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
[info]   at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
[info]   at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
[info]   at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
[info]   at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
[info]   at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
[info]   at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.org$scalatest$BeforeAndAfter$$super$run(StateStoreSuite.scala:90)
[info]   at org.scalatest.BeforeAndAfter.run(BeforeAndAfter.scala:273)
[info]   at org.scalatest.BeforeAndAfter.run$(BeforeAndAfter.scala:271)
[info]   at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.run(StateStoreSuite.scala:90)
[info]   at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:321)
[info]   at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:517)
[info]   at sbt.ForkMain$Run.lambda$runTest$1(ForkMain.java:414)
[info]   at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
[info]   at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
[info]   at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
[info]   at java.base/java.lang.Thread.run(Thread.java:840)
[info] Run completed in 2 seconds, 4 milliseconds.
[info] Total number of tests run: 1
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 0, failed 1, canceled 0, ignored 0, pending 0
[info] *** 1 TEST FAILED ***
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #46351 from huanliwang-db/race.

Authored-by: Huanli Wang <huanli.wangdatabricks.com>

Closes #46415 from huanliwang-db/race-3.5.

Authored-by: Huanli Wang <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
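The second bullet of the fix is essentially an ordering guarantee: the maintenance thread pool must be fully stopped before any provider is unloaded. The following is a minimal, self-contained sketch of that ordering in plain Scala using `java.util.concurrent`; the names `MaintenanceShutdownSketch`, `maintenancePool`, `loadedProviders`, and `stopAndUnload` are illustrative stand-ins, not Spark's actual `StateStore` internals.

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

object MaintenanceShutdownSketch {
  // Stand-ins for Spark's maintenance task pool and the map of loaded providers.
  private val maintenancePool = Executors.newSingleThreadScheduledExecutor()
  private val loadedProviders = new ConcurrentHashMap[String, AnyRef]()

  def stopAndUnload(): Unit = {
    // 1. Stop accepting new maintenance (snapshot) tasks and wait for any
    //    in-flight task to finish, so it cannot observe a half-unloaded provider.
    maintenancePool.shutdown()
    if (!maintenancePool.awaitTermination(5, TimeUnit.MINUTES)) {
      maintenancePool.shutdownNow()
    }
    // 2. Only after the pool has quiesced, drop the provider entries; anything a
    //    finished snapshot still references stays reachable until it is released.
    loadedProviders.clear()
  }
}
```

The key point is simply that the wait happens before the unload-and-clear step, which is the ordering the commit message's second bullet asks for.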
1 parent 74724d6 commit 07e08c0

3 files changed: +42 −9 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala

Lines changed: 0 additions & 8 deletions
```diff
@@ -31,7 +31,6 @@ trait HDFSBackedStateStoreMap {
   def remove(key: UnsafeRow): UnsafeRow
   def iterator(): Iterator[UnsafeRowPair]
   def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
-  def clear(): Unit
 }
 
 object HDFSBackedStateStoreMap {
@@ -79,8 +78,6 @@ class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap {
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Prefix scan is not supported!")
  }
-
-  override def clear(): Unit = map.clear()
 }
 
 class PrefixScannableHDFSBackedStateStoreMap(
@@ -169,9 +166,4 @@ class PrefixScannableHDFSBackedStateStoreMap(
       .iterator
       .map { key => unsafeRowPair.withRows(key, map.get(key)) }
   }
-
-  override def clear(): Unit = {
-    map.clear()
-    prefixKeyToKeysMap.clear()
-  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -262,7 +262,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   }
 
   override def close(): Unit = {
-    loadedMaps.values.asScala.foreach(_.clear())
+    // Clearing the map resets the TreeMap.root to null, and therefore entries inside the
+    // `loadedMaps` will be de-referenced and GCed automatically when their reference
+    // counts become 0.
+    synchronized { loadedMaps.clear() }
   }
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
```
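The new comment in `close()` leans on GC reachability: dropping the `loadedMaps` entries does not invalidate an inner map that a snapshotting thread has already grabbed, and the inner map is only collected once that reference goes away. Below is a tiny standalone sketch of that property, illustrative only and using plain Java collections rather than `HDFSBackedStateStoreMap`.

```scala
import java.util.concurrent.ConcurrentHashMap

object ClearVsUnloadSketch {
  def main(args: Array[String]): Unit = {
    val loadedMaps = new ConcurrentHashMap[Long, java.util.HashMap[String, Int]]()
    val inner = new java.util.HashMap[String, Int]()
    inner.put("a1", 1)
    inner.put("a2", 2)
    loadedMaps.put(1L, inner)

    // A "maintenance thread" takes a reference to the inner map first...
    val snapshotView = loadedMaps.get(1L)

    // ...then the provider is closed: only the outer entries are dropped.
    loadedMaps.clear()

    // The retained reference still sees the full data, so an in-flight snapshot
    // stays consistent; the inner map becomes GC-eligible once snapshotView is gone.
    assert(snapshotView.size() == 2 && snapshotView.get("a1") == 1)
  }
}
```

This is why the patch removes `clear()` from `HDFSBackedStateStoreMap` entirely: actively wiping the inner maps is what made a concurrently running snapshot observe empty data.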

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala

Lines changed: 38 additions & 0 deletions
```diff
@@ -272,6 +272,44 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
   }
 
+  test("SPARK-48105: state store unload/close happens during the maintenance") {
+    tryWithProviderResource(
+      newStoreProvider(opId = Random.nextInt(), partition = 0, minDeltasForSnapshot = 1)) {
+      provider =>
+        val store = provider.getStore(0).asInstanceOf[provider.HDFSBackedStateStore]
+        val values = (1 to 20)
+        val keys = values.map(i => ("a" + i))
+        keys.zip(values).map{case (k, v) => put(store, k, 0, v)}
+        // commit state store with 20 keys.
+        store.commit()
+        // get the state store iterator: mimic the case which the iterator is hold in the
+        // maintenance thread.
+        val storeIterator = store.iterator()
+
+        // the store iterator should still be valid as the maintenance thread may have already
+        // hold it and is doing snapshotting even though the state store is unloaded.
+        val outputKeys = new mutable.ArrayBuffer[String]
+        val outputValues = new mutable.ArrayBuffer[Int]
+        var cnt = 0
+        while (storeIterator.hasNext) {
+          if (cnt == 10) {
+            // Mimic the case where the provider is loaded in another executor in the middle of
+            // iteration. When this happens, the provider will be unloaded and closed in
+            // current executor.
+            provider.close()
+          }
+          val unsafeRowPair = storeIterator.next()
+          val (key, _) = keyRowToData(unsafeRowPair.key)
+          outputKeys.append(key)
+          outputValues.append(valueRowToData(unsafeRowPair.value))
+
+          cnt = cnt + 1
+        }
+        assert(keys.sorted === outputKeys.sorted)
+        assert(values.sorted === outputValues.sorted)
+    }
+  }
+
   test("maintenance") {
     val conf = new SparkConf()
       .setMaster("local")
```
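To run just this new test locally, Spark's usual sbt workflow should work, for example `build/sbt "sql/testOnly *StateStoreSuite -- -z SPARK-48105"` (this assumes the standard `sql` sbt project for the `sql/core` module and ScalaTest's `-z` substring filter).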
