diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
index aa393211a1c15..2461e63e98e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -40,8 +41,12 @@ trait AsyncLogPurge extends Logging {
 
   private val purgeRunning = new AtomicBoolean(false)
 
+  private val purgeOldestRunning = new AtomicBoolean(false)
+
   protected def purge(threshold: Long): Unit
 
+  protected def purgeOldest(plan: SparkPlan): Unit
+
   protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)
 
   protected def purgeAsync(batchId: Long): Unit = {
@@ -62,6 +67,24 @@ trait AsyncLogPurge extends Logging {
     }
   }
 
+  protected def purgeOldestAsync(plan: SparkPlan): Unit = {
+    if (purgeOldestRunning.compareAndSet(false, true)) {
+      asyncPurgeExecutorService.execute(() => {
+        try {
+          purgeOldest(plan)
+        } catch {
+          case throwable: Throwable =>
+            logError("Encountered error while performing async log purge", throwable)
+            errorNotifier.markError(throwable)
+        } finally {
+          purgeOldestRunning.set(false)
+        }
+      })
+    } else {
+      log.debug("Skipped log purging since there is already one in progress.")
+    }
+  }
+
   protected def asyncLogPurgeShutdown(): Unit = {
     ThreadUtils.shutdown(asyncPurgeExecutorService)
   }
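
The new `purgeOldestAsync` mirrors the existing `purgeAsync` single-flight pattern: an `AtomicBoolean` gate guarantees at most one purge task is in flight, and a request that arrives while one is running is skipped rather than queued. Below is a minimal, self-contained sketch of that pattern using a plain Java executor; the object and method names are illustrative, not Spark APIs.

```scala
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean

// Illustrative sketch (not Spark code): at most one purge task runs at a time;
// overlapping requests are dropped rather than queued.
object SingleFlightPurgeSketch {
  private val purgeRunning = new AtomicBoolean(false)
  private val executor = Executors.newSingleThreadExecutor()

  /** Returns true if the purge task was submitted, false if one was already running. */
  def purgeAsync(doPurge: () => Unit): Boolean = {
    if (purgeRunning.compareAndSet(false, true)) {
      executor.execute(() => {
        try {
          doPurge()
        } finally {
          purgeRunning.set(false) // always release the gate, even if the purge failed
        }
      })
      true
    } else {
      false
    }
  }

  def shutdown(): Unit = executor.shutdown()
}
```

Usage, for illustration only: `SingleFlightPurgeSketch.purgeAsync(() => println("purging"))` submits the task the first time and returns `false` for any call that overlaps it.
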
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 3fe5aeae5f637..4ce762714e864 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaV3File, StateStoreId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -58,7 +58,7 @@ class IncrementalExecution(
     val offsetSeqMetadata: OffsetSeqMetadata,
     val watermarkPropagator: WatermarkPropagator,
     val isFirstBatch: Boolean)
-  extends QueryExecution(sparkSession, logicalPlan) with Logging {
+  extends QueryExecution(sparkSession, logicalPlan) with Logging with AsyncLogPurge {
 
   // Modified planner with stateful operations.
   override val planner: SparkPlanner = new SparkPlanner(
@@ -79,6 +79,32 @@ class IncrementalExecution(
       StreamingTransformWithStateStrategy :: Nil
   }
 
+  // Methods to enable the use of AsyncLogPurge
+  protected val minLogEntriesToMaintain: Int =
+    sparkSession.sessionState.conf.minBatchesToRetain
+
+  val errorNotifier: ErrorNotifier = new ErrorNotifier()
+
+  override protected def purge(threshold: Long): Unit = {}
+
+  override protected def purgeOldest(planWithStateOpId: SparkPlan): Unit = {
+    planWithStateOpId.collect {
+      case tws: TransformWithStateExec =>
+        val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path(
+          checkpointLocation, tws.getStateInfo.operatorId.toString))
+        val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession,
+          metadataPath.toString)
+        val thresholdBatchId =
+          operatorStateMetadataLog.findThresholdBatchId(minLogEntriesToMaintain)
+        operatorStateMetadataLog.purge(thresholdBatchId)
+        val stateSchemaV3File = new StateSchemaV3File(
+          sparkSession.sessionState.newHadoopConf(),
+          path = tws.stateSchemaFilePath(Some(StateStoreId.DEFAULT_STORE_NAME)).toString)
+        stateSchemaV3File.purge(thresholdBatchId)
+      case _ =>
+    }
+  }
+
   private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
 
   private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -497,6 +523,14 @@ class IncrementalExecution(
       }
     }
 
+    def purgeMetadataFiles(planWithStateOpId: SparkPlan): Unit = {
+      if (useAsyncPurge) {
+        purgeOldestAsync(planWithStateOpId)
+      } else {
+        purgeOldest(planWithStateOpId)
+      }
+    }
+
     override def apply(plan: SparkPlan): SparkPlan = {
       val planWithStateOpId = plan transform composedRule
       // Need to check before write to metadata because we need to detect add operator
@@ -508,6 +542,7 @@ class IncrementalExecution(
       // The rule below doesn't change the plan but can cause the side effect that
       // metadata/schema is written in the checkpoint directory of stateful operator.
       planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule
+      purgeMetadataFiles(planWithStateOpId)
 
       simulateWatermarkPropagation(planWithStateOpId)
       planWithStateOpId transform WatermarkPropagationRule.rule
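
In `IncrementalExecution` the purge is driven from the planning rule: the physical plan is walked with `collect`, and only the `TransformWithStateExec` nodes trigger trimming of their operator metadata log and schema files. The sketch below illustrates that collect-with-a-partial-function traversal on a toy tree; the node types are invented stand-ins, not Spark's `SparkPlan`.

```scala
// Toy tree standing in for a physical plan; only the shape of the traversal matters.
sealed trait MiniPlan {
  def children: Seq[MiniPlan]
  // Pre-order collect, analogous to TreeNode.collect on a SparkPlan.
  def collect[B](pf: PartialFunction[MiniPlan, B]): Seq[B] =
    (if (pf.isDefinedAt(this)) Seq(pf(this)) else Nil) ++ children.flatMap(_.collect(pf))
}
case class StatefulNode(operatorId: Long, children: Seq[MiniPlan] = Nil) extends MiniPlan
case class StatelessNode(children: Seq[MiniPlan] = Nil) extends MiniPlan

object PurgeTraversalSketch extends App {
  val plan = StatelessNode(Seq(StatefulNode(0), StatelessNode(Seq(StatefulNode(1)))))
  // Analogous to planWithStateOpId.collect { case tws: TransformWithStateExec => ... }:
  // only the stateful nodes are visited, one purge per operator.
  val purgedOperators = plan.collect { case s: StatefulNode => s.operatorId }
  assert(purgedOperators == Seq(0L, 1L))
}
```
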
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index f636413f7c518..8fb1739f65fab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
 import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow}
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
 import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
@@ -929,6 +929,10 @@ class MicroBatchExecution(
       awaitProgressLock.unlock()
     }
   }
+
+  override protected def purgeOldest(plan: SparkPlan): Unit = {
+
+  }
 }
 
 object MicroBatchExecution {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala
index 09b4b65e3c3a3..19cd96eb59e96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala
@@ -61,4 +61,13 @@ class OperatorStateMetadataLog(
       case "v2" => OperatorStateMetadataUtils.deserialize(2, bufferedReader)
     }
   }
+
+  def findThresholdBatchId(minLogEntriesToMaintain: Int): Long = {
+    val metadataFiles = listBatches
+    if (metadataFiles.length > minLogEntriesToMaintain) {
+      metadataFiles.sorted.take(metadataFiles.length - minLogEntriesToMaintain).last + 1
+    } else {
+      -1
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala
new file mode 100644
index 0000000000000..e69de29bb2d1d
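
`findThresholdBatchId` keeps the newest `minLogEntriesToMaintain` metadata entries and returns the batch id one past the newest entry being dropped, or -1 when nothing is old enough to purge. A standalone restatement of that rule with a worked example, assuming the input is simply the list of batch ids that `listBatches` exposes:

```scala
// Standalone restatement of OperatorStateMetadataLog.findThresholdBatchId:
// retain the newest `minLogEntriesToMaintain` batches; return the id one past the
// newest batch being dropped, or -1 when there is nothing old enough to purge.
object ThresholdSketch extends App {
  def findThresholdBatchId(batchIds: Seq[Long], minLogEntriesToMaintain: Int): Long = {
    if (batchIds.length > minLogEntriesToMaintain) {
      batchIds.sorted.take(batchIds.length - minLogEntriesToMaintain).last + 1
    } else {
      -1
    }
  }

  // Five metadata files on disk, retain two: batches 0, 1 and 2 are dropped,
  // so the threshold is 3 and purge(3) deletes everything strictly below it.
  assert(findThresholdBatchId(Seq(0L, 1L, 2L, 3L, 4L), minLogEntriesToMaintain = 2) == 3L)
  // Not enough entries yet: -1 means skip purging.
  assert(findThresholdBatchId(Seq(0L, 1L), minLogEntriesToMaintain = 2) == -1L)
}
```
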
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 3031faa35b2d1..65b435b5c692c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -365,7 +365,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
       ttlConfig: TTLConfig): MapState[K, V] = {
     verifyStateVarOperations("get_map_state", PRE_INIT)
     val colFamilySchema = columnFamilySchemaUtils.
-      getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
+      getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true)
     columnFamilySchemas.put(stateName, colFamilySchema)
     null.asInstanceOf[MapState[K, V]]
   }
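
The one-line fix above swaps `valEncoder` and `userKeyEnc`, which had been passed to `getMapStateSchema` in the wrong positional order. A hedged illustration of how named arguments guard against that class of bug; the encoder and schema types and the helper signature below are stand-ins, not the real `ColumnFamilySchemaUtils` API.

```scala
// Stand-in types and helper, not the real ColumnFamilySchemaUtils API.
case class Enc[T](name: String)
case class MapStateSchema(
    stateName: String, keyEnc: String, valEnc: String, userKeyEnc: String, hasTtl: Boolean)

object NamedArgsSketch extends App {
  def getMapStateSchema[K, V, UK](
      stateName: String,
      keyExprEnc: Enc[K],
      valEncoder: Enc[V],
      userKeyEnc: Enc[UK],
      hasTtl: Boolean): MapStateSchema =
    MapStateSchema(stateName, keyExprEnc.name, valEncoder.name, userKeyEnc.name, hasTtl)

  // Named arguments keep the call site correct even when two parameters share a
  // shape or the positional order changes later.
  val schema = getMapStateSchema(
    stateName = "mapState",
    keyExprEnc = Enc[String]("groupingKey"),
    valEncoder = Enc[Int]("value"),
    userKeyEnc = Enc[String]("userKey"),
    hasTtl = true)
  assert(schema.valEnc == "value" && schema.userKeyEnc == "userKey")
}
```
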
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index c42d58ad67eac..0f43256e3b6c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -389,6 +389,11 @@ case class TransformWithStateExec(
     new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString)
   }
 
+  override def metadataFilePath(): Path = {
+    OperatorStateMetadataV2.metadataFilePath(
+      new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString))
+  }
+
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
@@ -397,9 +402,6 @@ case class TransformWithStateExec(
     val newSchemas = getColFamilySchemas()
     val schemaFile = new StateSchemaV3File(
       hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString)
-    // TODO: [SPARK-48849] Read the schema path from the OperatorStateMetadata file
-    // and validate it with the new schema
-
     val operatorStateMetadataLog = fetchOperatorStateMetadataLog(
       hadoopConf, getStateInfo.checkpointLocation, getStateInfo.operatorId)
     val mostRecentLog = operatorStateMetadataLog.getLatest()
@@ -429,17 +431,6 @@ case class TransformWithStateExec(
     }
   }
 
-  private def stateSchemaDirPath(storeName: String): Path = {
-    assert(storeName == StateStoreId.DEFAULT_STORE_NAME)
-    def stateInfo = getStateInfo
-    val stateCheckpointPath =
-      new Path(getStateInfo.checkpointLocation,
-        s"${stateInfo.operatorId.toString}")
-
-    val storeNamePath = new Path(stateCheckpointPath, storeName)
-    new Path(new Path(storeNamePath, "_metadata"), "schema")
-  }
-
   /** Metadata of this stateful operator and its states stores. */
   override def operatorStateMetadata(
       stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = {
@@ -458,6 +449,17 @@ case class TransformWithStateExec(
     OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
   }
 
+  private def stateSchemaDirPath(storeName: String): Path = {
+    assert(storeName == StateStoreId.DEFAULT_STORE_NAME)
+    def stateInfo = getStateInfo
+    val stateCheckpointPath =
+      new Path(getStateInfo.checkpointLocation,
+        s"${stateInfo.operatorId.toString}")
+
+    val storeNamePath = new Path(stateCheckpointPath, storeName)
+    new Path(new Path(storeNamePath, "_metadata"), "schema")
+  }
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala
index 482e802b7d87e..e7e94ae193664 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala
@@ -26,6 +26,7 @@ import scala.io.{Source => IOSource}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager
 import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion
 
@@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers
  */
 class StateSchemaV3File(
     hadoopConf: Configuration,
-    path: String) {
+    path: String) extends Logging {
 
   val metadataPath = new Path(path)
 
@@ -92,8 +93,37 @@ class StateSchemaV3File(
       throw e
     }
   }
+
+  // list all the files in the metadata directory
+  // sort by the batchId
+  private[sql] def listFiles(): Seq[Path] = {
+    fileManager.list(metadataPath).sorted.map(_.getPath).toSeq
+  }
+
+  private[sql] def listFilesBeforeBatch(batchId: Long): Seq[Path] = {
+    listFiles().filter { path =>
+      val batchIdInPath = path.getName.split("_").head.toLong
+      batchIdInPath < batchId
+    }
+  }
+
+  /**
+   * purge schema files that are before thresholdBatchId, exclusive
+   */
+  def purge(thresholdBatchId: Long): Unit = {
+    if (thresholdBatchId != -1) {
+      listFilesBeforeBatch(thresholdBatchId).foreach {
+        schemaFilePath =>
+          fileManager.delete(schemaFilePath)
+      }
+    }
+  }
 }
 
 object StateSchemaV3File {
   val VERSION = 3
+
+  private[sql] def getBatchFromPath(path: Path): Long = {
+    path.getName.split("_").head.toLong
+  }
 }
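
`StateSchemaV3File.purge` relies only on the file-name convention that each schema file name starts with its batch id followed by an underscore, which is exactly what `getBatchFromPath` parses. The sketch below restates the threshold semantics (delete strictly below the threshold, keep the rest); the file names are made up for illustration, and only the batch-id prefix matters. It uses Hadoop's `Path` like the real code.

```scala
import org.apache.hadoop.fs.Path

// Sketch of the batch-id-prefix convention and the purge threshold semantics.
// File names are invented; only the "<batchId>_" prefix is what the code relies on.
object SchemaPurgeSketch extends App {
  // Same parsing rule as StateSchemaV3File.getBatchFromPath.
  def batchIdOf(path: Path): Long = path.getName.split("_").head.toLong

  val schemaFiles = Seq(
    new Path("state/0/default/_metadata/schema/0_5f8a"),
    new Path("state/0/default/_metadata/schema/1_a310"),
    new Path("state/0/default/_metadata/schema/2_77cd"))

  // purge(thresholdBatchId = 2) deletes files whose batch id is strictly below 2.
  val thresholdBatchId = 2L
  val (deleted, kept) = schemaFiles.partition(f => batchIdOf(f) < thresholdBatchId)
  assert(deleted.map(batchIdOf) == Seq(0L, 1L))
  assert(kept.map(batchIdOf) == Seq(2L))
}
```
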
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 31a3ae648c054..fd621fc1161f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -857,7 +857,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       operatorId: Int): StateSchemaV3File = {
     val hadoopConf = spark.sessionState.newHadoopConf()
     val stateChkptPath = new Path(checkpointDir, s"state/$operatorId")
-    val stateSchemaPath = new Path(new Path(stateChkptPath, "_metadata"), "schema")
+    val storeNamePath = new Path(stateChkptPath, "default")
+    val stateSchemaPath = new Path(new Path(storeNamePath, "_metadata"), "schema")
     new StateSchemaV3File(hadoopConf, stateSchemaPath.toString)
   }
 
@@ -989,6 +990,49 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("transformWithState - verify that metadata logs are purged") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+      SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") {
+      withTempDir { chkptDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream,
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "2")),
+          StopStream,
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(),
+          StopStream
+        )
+        val operatorStateMetadataLogs =
+          fetchOperatorStateMetadataLog(chkptDir.getCanonicalPath, 0)
+            .listBatchesOnDisk
+        assert(operatorStateMetadataLogs.length == 1)
+        // Make sure that only the latest batch has the schema file
+        assert(operatorStateMetadataLogs.head == 2)
+
+        val schemaV3Files = fetchStateSchemaV3File(chkptDir.getCanonicalPath, 0).listFiles()
+        assert(schemaV3Files.length == 1)
+        assert(StateSchemaV3File.getBatchFromPath(schemaV3Files.head) == 2)
+      }
+    }
+  }
+
   test("transformWithState - verify OperatorStateMetadataV2 serialization and deserialization" +
     " works") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->