@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.ThreadUtils

@@ -40,8 +41,12 @@ trait AsyncLogPurge extends Logging {

private val purgeRunning = new AtomicBoolean(false)

private val purgeOldestRunning = new AtomicBoolean(false)

protected def purge(threshold: Long): Unit

protected def purgeOldest(plan: SparkPlan): Unit

protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)

protected def purgeAsync(batchId: Long): Unit = {
@@ -62,6 +67,24 @@ }
}
}

protected def purgeOldestAsync(plan: SparkPlan): Unit = {
if (purgeOldestRunning.compareAndSet(false, true)) {
asyncPurgeExecutorService.execute(() => {
try {
purgeOldest(plan)
} catch {
case throwable: Throwable =>
logError("Encountered error while performing async log purge", throwable)
errorNotifier.markError(throwable)
} finally {
purgeOldestRunning.set(false)
}
})
} else {
log.debug("Skipped log purging since there is already one in progress.")
}
}

protected def asyncLogPurgeShutdown(): Unit = {
ThreadUtils.shutdown(asyncPurgeExecutorService)
}
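
Note: the new purgeOldestAsync mirrors the guard the trait already uses for purgeAsync: an AtomicBoolean is flipped with compareAndSet so at most one purge task is in flight, and the flag is reset in a finally block so a failed purge does not block future ones. A minimal, self-contained sketch of that pattern, with a placeholder executor and task rather than the trait's actual members:

import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean

object SingleFlightPurgeSketch {
  private val running = new AtomicBoolean(false)
  private val executor = Executors.newSingleThreadExecutor()

  // Schedule `task` only if no earlier purge is still running; otherwise skip it,
  // which is the same decision purgeOldestAsync makes.
  def schedulePurge(task: () => Unit): Unit = {
    if (running.compareAndSet(false, true)) {
      executor.execute(() => {
        try task()
        finally running.set(false) // let the next purge be scheduled
      })
    } else {
      println("Skipped purge: one is already in progress")
    }
  }
}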
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaV3File, StateStoreId}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -58,7 +58,7 @@ class IncrementalExecution(
val offsetSeqMetadata: OffsetSeqMetadata,
val watermarkPropagator: WatermarkPropagator,
val isFirstBatch: Boolean)
extends QueryExecution(sparkSession, logicalPlan) with Logging {
extends QueryExecution(sparkSession, logicalPlan) with Logging with AsyncLogPurge {

// Modified planner with stateful operations.
override val planner: SparkPlanner = new SparkPlanner(
@@ -79,6 +79,32 @@
StreamingTransformWithStateStrategy :: Nil
}

// Methods to enable the use of AsyncLogPurge
protected val minLogEntriesToMaintain: Int =
sparkSession.sessionState.conf.minBatchesToRetain

val errorNotifier: ErrorNotifier = new ErrorNotifier()

override protected def purge(threshold: Long): Unit = {}

override protected def purgeOldest(planWithStateOpId: SparkPlan): Unit = {
planWithStateOpId.collect {
case tws: TransformWithStateExec =>
val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path(
checkpointLocation, tws.getStateInfo.operatorId.toString))
val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession,
metadataPath.toString)
val thresholdBatchId =
operatorStateMetadataLog.findThresholdBatchId(minLogEntriesToMaintain)
operatorStateMetadataLog.purge(thresholdBatchId)
val stateSchemaV3File = new StateSchemaV3File(
sparkSession.sessionState.newHadoopConf(),
path = tws.stateSchemaFilePath(Some(StateStoreId.DEFAULT_STORE_NAME)).toString)
stateSchemaV3File.purge(thresholdBatchId)
case _ =>
}
}

private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()

private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -497,6 +523,14 @@
}
}

def purgeMetadataFiles(planWithStateOpId: SparkPlan): Unit = {
if (useAsyncPurge) {
purgeOldestAsync(planWithStateOpId)
} else {
purgeOldest(planWithStateOpId)
}
}

override def apply(plan: SparkPlan): SparkPlan = {
val planWithStateOpId = plan transform composedRule
// Need to check before write to metadata because we need to detect add operator
@@ -508,6 +542,7 @@
// The rule below doesn't change the plan but can cause the side effect that
// metadata/schema is written in the checkpoint directory of stateful operator.
planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule
purgeMetadataFiles(planWithStateOpId)

simulateWatermarkPropagation(planWithStateOpId)
planWithStateOpId transform WatermarkPropagationRule.rule
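
Note: purgeOldest relies on the collect method SparkPlan inherits from TreeNode, which applies a partial function to every node in the plan tree and gathers the results, so every TransformWithStateExec in the query (not just the first one encountered) has its metadata log and schema files trimmed. A simplified, Spark-free sketch of that traversal pattern; the Node hierarchy below is a stand-in, not Spark's TreeNode:

object CollectSketch {
  sealed trait Node { def children: Seq[Node] }
  case class Stateful(name: String, children: Seq[Node] = Nil) extends Node
  case class Stateless(children: Seq[Node] = Nil) extends Node

  // Pre-order traversal collecting every node the partial function matches,
  // analogous to `planWithStateOpId.collect { case tws: TransformWithStateExec => ... }`.
  def collect[B](node: Node)(pf: PartialFunction[Node, B]): Seq[B] =
    pf.lift(node).toSeq ++ node.children.flatMap(collect(_)(pf))

  def main(args: Array[String]): Unit = {
    val plan = Stateless(Seq(Stateful("tws-1"), Stateless(Seq(Stateful("tws-2")))))
    println(collect(plan) { case s: Stateful => s.name }) // List(tws-1, tws-2)
  }
}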
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
@@ -929,6 +929,10 @@ class MicroBatchExecution(
awaitProgressLock.unlock()
}
}

override protected def purgeOldest(plan: SparkPlan): Unit = {

}
}

object MicroBatchExecution {
@@ -61,4 +61,13 @@ class OperatorStateMetadataLog(
case "v2" => OperatorStateMetadataUtils.deserialize(2, bufferedReader)
}
}

def findThresholdBatchId(minLogEntriesToMaintain: Int): Long = {
val metadataFiles = listBatches
if (metadataFiles.length > minLogEntriesToMaintain) {
metadataFiles.sorted.take(metadataFiles.length - minLogEntriesToMaintain).last + 1
} else {
-1
}
}
}
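
Note: findThresholdBatchId returns the id of the oldest batch to keep, or -1 when there are not yet more log entries than minLogEntriesToMaintain (the -1 sentinel is what StateSchemaV3File.purge checks before deleting anything). The arithmetic in isolation, with a plain list of batch ids standing in for listBatches:

object ThresholdSketch {
  // Same computation as findThresholdBatchId, over a plain list of batch ids.
  def thresholdBatchId(batchIds: Seq[Long], minLogEntriesToMaintain: Int): Long =
    if (batchIds.length > minLogEntriesToMaintain) {
      batchIds.sorted.take(batchIds.length - minLogEntriesToMaintain).last + 1
    } else {
      -1
    }

  def main(args: Array[String]): Unit = {
    // Three committed batches, retain one: batches 0 and 1 are purgeable,
    // so the threshold (the first batch id to keep) is 2.
    println(thresholdBatchId(Seq(0L, 1L, 2L), 1)) // 2
    // Not enough entries yet: -1 signals that nothing should be purged.
    println(thresholdBatchId(Seq(0L), 1))         // -1
  }
}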
@@ -365,7 +365,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
ttlConfig: TTLConfig): MapState[K, V] = {
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[MapState[K, V]]
}
@@ -389,6 +389,11 @@ case class TransformWithStateExec(
new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString)
}

override def metadataFilePath(): Path = {
OperatorStateMetadataV2.metadataFilePath(
new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString))
}

override def validateAndMaybeEvolveStateSchema(
hadoopConf: Configuration,
batchId: Long,
@@ -397,9 +402,6 @@
val newSchemas = getColFamilySchemas()
val schemaFile = new StateSchemaV3File(
hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString)
// TODO: [SPARK-48849] Read the schema path from the OperatorStateMetadata file
// and validate it with the new schema

val operatorStateMetadataLog = fetchOperatorStateMetadataLog(
hadoopConf, getStateInfo.checkpointLocation, getStateInfo.operatorId)
val mostRecentLog = operatorStateMetadataLog.getLatest()
@@ -429,17 +431,6 @@
}
}

private def stateSchemaDirPath(storeName: String): Path = {
assert(storeName == StateStoreId.DEFAULT_STORE_NAME)
def stateInfo = getStateInfo
val stateCheckpointPath =
new Path(getStateInfo.checkpointLocation,
s"${stateInfo.operatorId.toString}")

val storeNamePath = new Path(stateCheckpointPath, storeName)
new Path(new Path(storeNamePath, "_metadata"), "schema")
}

/** Metadata of this stateful operator and its states stores. */
override def operatorStateMetadata(
stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = {
@@ -458,6 +449,17 @@
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
}

private def stateSchemaDirPath(storeName: String): Path = {
assert(storeName == StateStoreId.DEFAULT_STORE_NAME)
def stateInfo = getStateInfo
val stateCheckpointPath =
new Path(getStateInfo.checkpointLocation,
s"${stateInfo.operatorId.toString}")

val storeNamePath = new Path(stateCheckpointPath, storeName)
new Path(new Path(storeNamePath, "_metadata"), "schema")
}

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

@@ -26,6 +26,7 @@ import scala.io.{Source => IOSource}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion

@@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion
*/
class StateSchemaV3File(
hadoopConf: Configuration,
path: String) {
path: String) extends Logging {

val metadataPath = new Path(path)

@@ -92,8 +93,37 @@
throw e
}
}

// List all the files in the metadata directory, sorted by batchId
private[sql] def listFiles(): Seq[Path] = {
fileManager.list(metadataPath).sorted.map(_.getPath).toSeq
}

private[sql] def listFilesBeforeBatch(batchId: Long): Seq[Path] = {
listFiles().filter { path =>
val batchIdInPath = path.getName.split("_").head.toLong
batchIdInPath < batchId
}
}

/**
* Purge schema files for batches strictly earlier than thresholdBatchId; the threshold batch itself is retained.
*/
def purge(thresholdBatchId: Long): Unit = {
if (thresholdBatchId != -1) {
listFilesBeforeBatch(thresholdBatchId).foreach {
schemaFilePath =>
fileManager.delete(schemaFilePath)
}
}
}
}

object StateSchemaV3File {
val VERSION = 3

private[sql] def getBatchFromPath(path: Path): Long = {
path.getName.split("_").head.toLong
}
}
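
Note: listFilesBeforeBatch and getBatchFromPath both assume schema file names begin with the batch id followed by an underscore (for example 2_<uuid>); the directory layout below matches the test helper's state/<operatorId>/default/_metadata/schema path, while the file-name suffixes are made-up examples. A sketch of the parsing and filtering:

import org.apache.hadoop.fs.Path

object SchemaFileNamesSketch {
  // Same parsing rule as StateSchemaV3File.getBatchFromPath: "<batchId>_<suffix>".
  def batchOf(path: Path): Long = path.getName.split("_").head.toLong

  def main(args: Array[String]): Unit = {
    val files = Seq(
      new Path("/chk/state/0/default/_metadata/schema/0_a1b2"),
      new Path("/chk/state/0/default/_metadata/schema/1_c3d4"),
      new Path("/chk/state/0/default/_metadata/schema/2_e5f6"))

    // Mirrors listFilesBeforeBatch(thresholdBatchId = 2): the files for batches 0
    // and 1 are eligible for deletion, the file for batch 2 is kept.
    files.filter(p => batchOf(p) < 2L).foreach(p => println(s"would delete $p"))
  }
}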
@@ -857,7 +857,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
operatorId: Int): StateSchemaV3File = {
val hadoopConf = spark.sessionState.newHadoopConf()
val stateChkptPath = new Path(checkpointDir, s"state/$operatorId")
val stateSchemaPath = new Path(new Path(stateChkptPath, "_metadata"), "schema")
val storeNamePath = new Path(stateChkptPath, "default")
val stateSchemaPath = new Path(new Path(storeNamePath, "_metadata"), "schema")
new StateSchemaV3File(hadoopConf, stateSchemaPath.toString)
}

@@ -989,6 +990,49 @@
}
}


test("transformWithState - verify that metadata logs are purged") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") {
withTempDir { chkptDir =>
val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = chkptDir.getCanonicalPath),
AddData(inputData, "a"),
CheckNewAnswer(("a", "1")),
StopStream,
StartStream(checkpointLocation = chkptDir.getCanonicalPath),
AddData(inputData, "a"),
CheckNewAnswer(("a", "2")),
StopStream,
StartStream(checkpointLocation = chkptDir.getCanonicalPath),
AddData(inputData, "a"),
CheckNewAnswer(),
StopStream
)
val operatorStateMetadataLogs =
fetchOperatorStateMetadataLog(chkptDir.getCanonicalPath, 0)
.listBatchesOnDisk
assert(operatorStateMetadataLogs.length == 1)
// Only the metadata entry and schema file for the latest batch should remain
assert(operatorStateMetadataLogs.head == 2)

val schemaV3Files = fetchStateSchemaV3File(chkptDir.getCanonicalPath, 0).listFiles()
assert(schemaV3Files.length == 1)
assert(StateSchemaV3File.getBatchFromPath(schemaV3Files.head) == 2)
}
}
}

test("transformWithState - verify OperatorStateMetadataV2 serialization and deserialization" +
" works") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->