diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e2c48e2d8a14c..1c9b047e59748 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -825,6 +825,15 @@ object SQLConf {
     .intConf
     .createWithDefault(100)
 
+  val MAX_BATCHES_TO_RETAIN_IN_MEMORY = buildConf("spark.sql.streaming.maxBatchesToRetainInMemory")
+    .internal()
+    .doc("The maximum number of batches which will be retained in memory to avoid " +
+      "loading from files. The value adjusts a trade-off between memory usage and cache misses: " +
+      "'2' covers both success and direct failure cases, '1' covers only the success case, " +
+      "and '0' covers the extreme case - disabling the cache to maximize the memory available to executors.")
+    .intConf
+    .createWithDefault(2)
+
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
       .internal()
@@ -1463,6 +1472,8 @@ class SQLConf extends Serializable with Logging {
 
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
 
+  def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
+
   def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
 
   def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 118c82aa75e68..523acef34ca61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io._
+import java.util
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -203,6 +204,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     this.valueSchema = valueSchema
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
+    this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
     fm.mkdirs(baseDir)
   }
 
@@ -220,7 +222,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   }
 
   override def close(): Unit = {
-    loadedMaps.values.foreach(_.clear())
+    loadedMaps.values.asScala.foreach(_.clear())
   }
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
@@ -239,8 +241,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   @volatile private var valueSchema: StructType = _
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
+  @volatile private var numberOfVersionsToRetainInMemory: Int = _
 
-  private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
+  private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
   private lazy val baseDir = stateStoreId.storeCheckpointLocation()
   private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
   private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
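
The change above replaces the unordered mutable.HashMap cache with a java.util.TreeMap keyed by version in descending order, so the most recently committed version is always firstKey() and the oldest cached version is lastKey(). A minimal standalone sketch of that ordering assumption (not part of the patch; the String values stand in for the cached maps):

    import java.util

    object TreeMapOrderingSketch {
      def main(args: Array[String]): Unit = {
        // descending ordering by version, mirroring the new loadedMaps field
        val loadedMaps = new util.TreeMap[Long, String](Ordering[Long].reverse)
        loadedMaps.put(1L, "v1")
        loadedMaps.put(3L, "v3")
        loadedMaps.put(2L, "v2")
        assert(loadedMaps.firstKey() == 3L) // newest cached version
        assert(loadedMaps.lastKey() == 1L)  // oldest cached version, first eviction candidate
      }
    }
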
@@ -250,7 +253,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = {
     synchronized {
       finalizeDeltaFile(output)
-      loadedMaps.put(newVersion, map)
+      putStateIntoStateCacheMap(newVersion, map)
     }
   }
 
@@ -260,7 +263,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
    */
   private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized {
     val versionsInFiles = fetchFiles().map(_.version).toSet
-    val versionsLoaded = loadedMaps.keySet
+    val versionsLoaded = loadedMaps.keySet.asScala
     val allKnownVersions = versionsInFiles ++ versionsLoaded
     val unsafeRowTuple = new UnsafeRowPair()
     if (allKnownVersions.nonEmpty) {
@@ -270,11 +273,43 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     } else Iterator.empty
   }
 
+  /** This method is intended to be used only in unit tests. DO NOT TOUCH ELEMENTS IN THE MAP! */
+  private[state] def getLoadedMaps(): util.SortedMap[Long, MapType] = synchronized {
+    // shallow copy as a minimal guard
+    loadedMaps.clone().asInstanceOf[util.SortedMap[Long, MapType]]
+  }
+
+  private def putStateIntoStateCacheMap(newVersion: Long, map: MapType): Unit = synchronized {
+    if (numberOfVersionsToRetainInMemory <= 0) {
+      if (loadedMaps.size() > 0) loadedMaps.clear()
+      return
+    }
+
+    while (loadedMaps.size() > numberOfVersionsToRetainInMemory) {
+      loadedMaps.remove(loadedMaps.lastKey())
+    }
+
+    val size = loadedMaps.size()
+    if (size == numberOfVersionsToRetainInMemory) {
+      val versionIdForLastKey = loadedMaps.lastKey()
+      if (versionIdForLastKey > newVersion) {
+        // this is the only case where we can skip putting: the new version would become
+        // the last key and would be evicted right away
+        return
+      } else if (versionIdForLastKey < newVersion) {
+        // this case needs removal of the last key before putting the new one
+        loadedMaps.remove(versionIdForLastKey)
+      }
+    }
+
+    loadedMaps.put(newVersion, map)
+  }
+
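
putStateIntoStateCacheMap caps the cache at numberOfVersionsToRetainInMemory entries and, once the cache is full, refuses to cache any version older than the oldest key it already holds, since such an entry would be evicted immediately. A standalone sketch (not part of the patch) applying the same rules to a cache of plain version numbers:

    import java.util

    object EvictionSketch {
      private val retain = 2 // stands in for numberOfVersionsToRetainInMemory
      private val cache = new util.TreeMap[Long, String](Ordering[Long].reverse)

      def putIntoCache(version: Long, value: String): Unit = {
        if (retain <= 0) { cache.clear(); return }
        while (cache.size() > retain) cache.remove(cache.lastKey())
        if (cache.size() == retain) {
          val oldest = cache.lastKey()
          if (oldest > version) return                    // would be evicted right away, skip
          else if (oldest < version) cache.remove(oldest) // make room for the newer version
        }
        cache.put(version, value)
      }

      def main(args: Array[String]): Unit = {
        putIntoCache(1L, "v1"); putIntoCache(2L, "v2"); putIntoCache(3L, "v3")
        assert(cache.keySet().toString == "[3, 2]") // version 1 was evicted
        putIntoCache(1L, "v1-again")
        assert(cache.keySet().toString == "[3, 2]") // too old: never enters the cache
      }
    }
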
   /** Load the required version of the map data from the backing files */
   private def loadMap(version: Long): MapType = {
     // Shortcut if the map for this version is already there to avoid a redundant put.
-    val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) }
+    val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) }
     if (loadedCurrentVersionMap.isDefined) {
       return loadedCurrentVersionMap.get
     }
 
@@ -286,7 +321,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     val (result, elapsedMs) = Utils.timeTakenMs {
       val snapshotCurrentVersionMap = readSnapshotFile(version)
       if (snapshotCurrentVersionMap.isDefined) {
-        synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) }
+        synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) }
         return snapshotCurrentVersionMap.get
       }
 
@@ -302,7 +337,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
           lastAvailableMap = Some(new MapType)
         } else {
           lastAvailableMap =
-            synchronized { loadedMaps.get(lastAvailableVersion) }
+            synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
               .orElse(readSnapshotFile(lastAvailableVersion))
         }
       }
@@ -314,7 +349,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
         updateFromDeltaFile(deltaVersion, resultMap)
       }
 
-      synchronized { loadedMaps.put(version, resultMap) }
+      synchronized { putStateIntoStateCacheMap(version, resultMap) }
       resultMap
     }
 
@@ -506,7 +541,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       val lastVersion = files.last.version
       val deltaFilesForLastVersion =
         filesForVersion(files, lastVersion).filter(_.isSnapshot == false)
-      synchronized { loadedMaps.get(lastVersion) } match {
+      synchronized { Option(loadedMaps.get(lastVersion)) } match {
         case Some(map) =>
           if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) {
             val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map))
@@ -536,10 +571,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
       if (earliestVersionToRetain > 0) {
         val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head
-        synchronized {
-          val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq
-          mapsToRemove.foreach(loadedMaps.remove)
-        }
         val filesToDelete = files.filter(_.version < earliestFileToRetain.version)
         val (_, e2) = Utils.timeTakenMs {
           filesToDelete.foreach { f =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 765ff076cb467..d145082a39b57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -34,6 +34,9 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
   /** Minimum versions a State Store implementation should retain to allow rollbacks */
   val minVersionsToRetain: Int = sqlConf.minBatchesToRetain
 
+  /** Maximum count of versions a State Store implementation should retain in memory */
+  val maxVersionsToRetainInMemory: Int = sqlConf.maxBatchesToRetainInMemory
+
   /**
    * Optional fully qualified name of the subclass of [[StateStoreProvider]]
    * managing state data. That is, the implementation of the State Store to use.
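
Since spark.sql.streaming.maxBatchesToRetainInMemory is an internal SQL conf that StateStoreConf reads per query, it would typically be set on the session (or passed via --conf) before the streaming query starts. A hedged usage sketch (not part of the patch; the master, app name, and value of 1 are illustrative only):

    import org.apache.spark.sql.SparkSession

    object MaxBatchesToRetainInMemorySketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("state-cache-sketch")
          // trade a possible cache miss on failure for lower executor memory usage
          .config("spark.sql.streaming.maxBatchesToRetainInMemory", "1")
          .getOrCreate()

        // equivalent runtime setting, applied before any streaming query starts
        spark.conf.set("spark.sql.streaming.maxBatchesToRetainInMemory", "1")
      }
    }
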
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 73f8705060402..bfeb2b16ff7be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.{File, IOException}
 import java.net.URI
+import java.util
 import java.util.UUID
 
 import scala.collection.JavaConverters._
@@ -47,6 +48,7 @@ import org.apache.spark.util.Utils
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   with BeforeAndAfter with PrivateMethodTester {
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
+  type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
   import StateStoreCoordinatorSuite._
   import StateStoreTestsHelper._
@@ -64,21 +66,143 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     require(!StateStore.isMaintenanceRunning)
   }
 
+  def updateVersionTo(
+      provider: StateStoreProvider,
+      currentVersion: Int,
+      targetVersion: Int): Int = {
+    var newCurrentVersion = currentVersion
+    for (i <- newCurrentVersion until targetVersion) {
+      newCurrentVersion = incrementVersion(provider, i)
+    }
+    require(newCurrentVersion === targetVersion)
+    newCurrentVersion
+  }
+
+  def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = {
+    val store = provider.getStore(currentVersion)
+    put(store, "a", currentVersion + 1)
+    store.commit()
+    currentVersion + 1
+  }
+
+  def checkLoadedVersions(
+      loadedMaps: util.SortedMap[Long, ProviderMapType],
+      count: Int,
+      earliestKey: Long,
+      latestKey: Long): Unit = {
+    assert(loadedMaps.size() === count)
+    assert(loadedMaps.firstKey() === earliestKey)
+    assert(loadedMaps.lastKey() === latestKey)
+  }
+
+  def checkVersion(
+      loadedMaps: util.SortedMap[Long, ProviderMapType],
+      version: Long,
+      expectedData: Map[String, Int]): Unit = {
+
+    val originValueMap = loadedMaps.get(version).asScala.map { entry =>
+      rowToString(entry._1) -> rowToInt(entry._2)
+    }.toMap
+
+    assert(originValueMap === expectedData)
+  }
+
+  test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 2)
+
+    var currentVersion = 0
+
+    // commit version 1: the cache will have one element
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 2: the cache will have two elements
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1)
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 3: the cache already has two elements, so adding version 3 exceeds the
+    // cache size; version 3 will be added and version 1 will be evicted
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 3))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2)
+    checkVersion(loadedMaps, 3, Map("a" -> 3))
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+  }
+
+  test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 1)
+
+    var currentVersion = 0
+
+    // commit version 1: the cache will have one element
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 2: the cache already has one element, so adding version 2 exceeds the
+    // cache size; version 2 will be added and version 1 will be evicted.
+    // This guarantees a cache miss when this partition commits successfully
+    // but a failure afterwards forces it to reprocess the previous batch.
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+
+    // suppose there has been a failure after committing, and the query decides to reprocess
+    // the previous batch
+    currentVersion = 1
+
+    // commit to an existing version which was committed partially but abandoned globally
+    val store = provider.getStore(currentVersion)
+    // negative value to represent reprocessing
+    put(store, "a", -2)
+    store.commit()
+    currentVersion += 1
+
+    // make sure the newly committed version is reflected in the cache (overwritten)
+    assert(getData(provider) === Set("a" -> -2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
+    checkVersion(loadedMaps, 2, Map("a" -> -2))
+  }
+
+  test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 0)
+
+    var currentVersion = 0
+
+    // commit version 1: never cached
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    assert(loadedMaps.size() === 0)
+
+    // commit version 2: never cached
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    assert(loadedMaps.size() === 0)
+  }
+
   test("snapshotting") {
     val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
       minDeltasForSnapshot = 5)
     var currentVersion = 0
 
-    def updateVersionTo(targetVersion: Int): Unit = {
-      for (i <- currentVersion + 1 to targetVersion) {
-        val store = provider.getStore(currentVersion)
-        put(store, "a", i)
-        store.commit()
-        currentVersion += 1
-      }
-      require(currentVersion === targetVersion)
-    }
-    updateVersionTo(2)
+    currentVersion = updateVersionTo(provider, currentVersion, 2)
     require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance() // should not generate snapshot files
     assert(getData(provider) === Set("a" -> 2))
@@ -89,7 +213,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
 
     // After version 6, snapshotting should generate one snapshot file
-    updateVersionTo(6)
+    currentVersion = updateVersionTo(provider, currentVersion, 6)
"store not updated correctly") provider.doMaintenance() // should generate snapshot files @@ -104,7 +228,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files - updateVersionTo(20) + currentVersion = updateVersionTo(provider, currentVersion, 20) require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot @@ -535,9 +659,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] partition: Int, dir: String = newDir(), minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get, hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory) sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val provider = new HDFSBackedStateStoreProvider() provider.init(