From d9d9d995573070c58c0221d634918f3cc3e9e96a Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 5 Jun 2024 17:27:50 -0700 Subject: [PATCH 1/9] adding OperatorStateMetadataLog --- .../state/metadata/StateMetadataSource.scala | 14 ++- .../streaming/IncrementalExecution.scala | 14 +-- .../streaming/MicroBatchExecution.scala | 27 ++++- .../streaming/OperatorStateMetadataLog.scala | 54 +++++++++ .../state/OperatorStateMetadata.scala | 103 ++++++++++++++++-- .../streaming/statefulOperators.scala | 8 ++ 6 files changed, 192 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 28de21aaf9389..632d91bae6c50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH -import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -55,7 +55,8 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, - numColsPrefixKey)) + numColsPrefixKey + )) } } @@ -193,8 +194,11 @@ class StateMetadataPartitionReader( val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted opIds.map { opId => - new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() - } + val dirLocation = new Path(stateDir, opId.toString) + val metadataFilePath = OperatorStateMetadata.metadataFilePath(dirLocation) + val log = new OperatorStateMetadataLog(SparkSession.active, metadataFilePath.toString) + log.getLatest() + }.filter(_.isDefined).map(_.get._2) } private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index f72e2eb407f84..55aef633a7365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike 
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataWriter} +import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -187,17 +187,6 @@ class IncrementalExecution( } } - object WriteStatefulOperatorMetadataRule extends SparkPlanPartialRule { - override val rule: PartialFunction[SparkPlan, SparkPlan] = { - case stateStoreWriter: StateStoreWriter if isFirstBatch => - val metadata = stateStoreWriter.operatorStateMetadata() - val metadataWriter = new OperatorStateMetadataWriter(new Path( - checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf) - metadataWriter.write(metadata) - stateStoreWriter - } - } - object StateOpIdRule extends SparkPlanPartialRule { override val rule: PartialFunction[SparkPlan, SparkPlan] = { case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, @@ -473,7 +462,6 @@ class IncrementalExecution( } // The rule doesn't change the plan but cause the side effect that metadata is written // in the checkpoint directory of stateful operator. - planWithStateOpId transform WriteStatefulOperatorMetadataRule.rule simulateWatermarkPropagation(planWithStateOpId) planWithStateOpId transform WatermarkPropagationRule.rule } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ef49de5e0857d..ae4333403e495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} @@ -88,6 +88,22 @@ class MicroBatchExecution( @volatile protected[sql] var triggerExecutor: TriggerExecutor = _ + private lazy val operatorStateMetadatas: Map[Long, OperatorStateMetadataLog] = { + populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) + } + + private def populateOperatorStateMetadatas(plan: SparkPlan): + Map[Long, OperatorStateMetadataLog] = { + plan.flatMap { + case s: StateStoreWriter => s.stateInfo.map { info => + val metadataPath = s.metadataFilePath() + info.operatorId -> new OperatorStateMetadataLog(sparkSession, + metadataPath.toString) + } + case _ => Seq.empty + }.toMap + 
} + protected def getTrigger(): TriggerExecutor = { assert(sources.nonEmpty, "sources should have been retrieved from the plan!") trigger match { @@ -902,6 +918,15 @@ class MicroBatchExecution( if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } + execCtx.executionPlan.executedPlan.collect { + case s: StateStoreWriter => + val metadata = s.operatorStateMetadata() + val id = metadata.operatorInfo.operatorId + val metadataFile = operatorStateMetadatas(id) + if (!metadataFile.add(execCtx.batchId, metadata)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) + } + } } committedOffsets ++= execCtx.endOffsets } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala new file mode 100644 index 0000000000000..5b30f85938818 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} +import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets._ + +import org.apache.hadoop.fs.FSDataOutputStream + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} + + +class OperatorStateMetadataLog(sparkSession: SparkSession, path: String) + extends HDFSMetadataLog[OperatorStateMetadata](sparkSession, path) { + override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = { + val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream] + fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8)) + metadata.version match { + case 1 => + OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata) + case 2 => + OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata) + } + } + + override protected def deserialize(in: InputStream): OperatorStateMetadata = { + // called inside a try-finally where the underlying stream is closed in the caller + // create buffered reader from input stream + val bufferedReader = new BufferedReader(new InputStreamReader(in, UTF_8)) + // read first line for version number, in the format "v{version}" + val version = bufferedReader.readLine() + version match { + case "v1" => OperatorStateMetadataV1.deserialize(bufferedReader) + case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index b58c805af9d60..1d016522309a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -25,10 +25,13 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataOutputStream, Path} import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.JValue import org.json4s.jackson.Serialization +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} +import org.apache.spark.util.AccumulatorV2 /** * Metadata for a state store instance. @@ -54,6 +57,15 @@ case class OperatorInfoV1(operatorId: Long, operatorName: String) extends Operat trait OperatorStateMetadata { def version: Int + + def operatorInfo: OperatorInfo + + def stateStoreInfo: Array[StateStoreMetadataV1] +} + +object OperatorStateMetadata { + def metadataFilePath(stateCheckpointPath: Path): Path = + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") } case class OperatorStateMetadataV1( @@ -62,6 +74,56 @@ case class OperatorStateMetadataV1( override def version: Int = 1 } +/** + * Accumulator to store arbitrary Operator properties. + * This accumulator is used to store the properties of an operator that are not + * available on the driver at the time of planning, and will only be known from + * the executor side. 
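+ *
+ * A minimal usage sketch (illustrative only; the property map passed to `add` below is a
+ * made-up example, but `create`, `add` and `value` are the members defined here):
+ * {{{
+ *   // driver, at planning time: create and register the accumulator
+ *   val props = OperatorProperties.create(sc, "colFamilyMetadata")
+ *   // executor, while the stateful operator runs: record discovered properties
+ *   props.add(Map("stateVariables" -> JArray(List(JString("countState")))))
+ *   // driver, once the tasks complete: read the merged map
+ *   val merged: Map[String, JValue] = props.value
+ * }}}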
+ */ +class OperatorProperties(initValue: Map[String, JValue] = Map.empty) + extends AccumulatorV2[Map[String, JValue], Map[String, JValue]] { + + private var _value: Map[String, JValue] = initValue + + override def isZero: Boolean = _value.isEmpty + + override def copy(): AccumulatorV2[Map[String, JValue], Map[String, JValue]] = { + val newAcc = new OperatorProperties + newAcc._value = _value + newAcc + } + + override def reset(): Unit = _value = Map.empty[String, JValue] + + override def add(v: Map[String, JValue]): Unit = _value ++= v + + override def merge(other: AccumulatorV2[Map[String, JValue], Map[String, JValue]]): Unit = { + _value ++= other.value + } + + override def value: Map[String, JValue] = _value +} + +object OperatorProperties { + def create( + sc: SparkContext, + name: String, + initValue: Map[String, JValue] = Map.empty): OperatorProperties = { + val acc = new OperatorProperties(initValue) + acc.register(sc, name = Some(name)) + acc + } +} + +// operatorProperties is an arbitrary JSON formatted string that contains +// any properties that we would want to store for a particular operator. +case class OperatorStateMetadataV2( + operatorInfo: OperatorInfoV1, + stateStoreInfo: Array[StateStoreMetadataV1], + operatorPropertiesJson: String) extends OperatorStateMetadata { + override def version: Int = 2 +} + object OperatorStateMetadataV1 { private implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -70,9 +132,6 @@ object OperatorStateMetadataV1 { private implicit val manifest = Manifest .classType[OperatorStateMetadataV1](implicitly[ClassTag[OperatorStateMetadataV1]].runtimeClass) - def metadataFilePath(stateCheckpointPath: Path): Path = - new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") - def deserialize(in: BufferedReader): OperatorStateMetadata = { Serialization.read[OperatorStateMetadataV1](in) } @@ -84,13 +143,31 @@ object OperatorStateMetadataV1 { } } +object OperatorStateMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) + + def deserialize(in: BufferedReader): OperatorStateMetadata = { + Serialization.read[OperatorStateMetadataV2](in) + } + + def serialize( + out: FSDataOutputStream, + operatorStateMetadata: OperatorStateMetadata): Unit = { + Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out) + } +} + /** * Write OperatorStateMetadata into the state checkpoint directory. 
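 * The file is written to `<stateCheckpointPath>/_metadata/metadata`; it starts with a
 * `v<version>` line and is followed by the JSON-serialized metadata for that version
 * (V1 or V2).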
*/ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configuration) extends Logging { - private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadata.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) @@ -101,7 +178,12 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false) try { outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8)) - OperatorStateMetadataV1.serialize(outputStream, operatorMetadata) + operatorMetadata.version match { + case 1 => + OperatorStateMetadataV1.serialize(outputStream, operatorMetadata) + case 2 => + OperatorStateMetadataV2.serialize(outputStream, operatorMetadata) + } outputStream.close() } catch { case e: Throwable => @@ -117,7 +199,7 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu */ class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) { - private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadata.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) @@ -127,9 +209,12 @@ class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configu new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) try { val versionStr = inputReader.readLine() - val version = MetadataVersionUtil.validateVersion(versionStr, 1) - assert(version == 1) - OperatorStateMetadataV1.deserialize(inputReader) + val version = MetadataVersionUtil.validateVersion(versionStr, 2) + assert(version == 1 || version == 2) + version match { + case 1 => OperatorStateMetadataV1.deserialize(inputReader) + case 2 => OperatorStateMetadataV2.deserialize(inputReader) + } } finally { inputStream.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 9add574e01fc5..9d6b4cb55b02a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -23,6 +23,8 @@ import java.util.concurrent.TimeUnit._ import scala.collection.mutable import scala.jdk.CollectionConverters._ +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException @@ -70,6 +72,12 @@ trait StatefulOperator extends SparkPlan { throw new IllegalStateException("State location not present for execution") } } + + def metadataFilePath(): Path = { + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + } } /** From 838c239c02becdaced126014f18abc3dce9a8000 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 6 Jun 2024 09:13:58 -0700 Subject: [PATCH 2/9] moving it to streamexecution --- .../streaming/MicroBatchExecution.scala | 20 ++----------------- .../execution/streaming/StreamExecution.scala | 17 ++++++++++++++++ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ae4333403e495..acfaeca10c5ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} @@ -88,22 +88,6 @@ class MicroBatchExecution( @volatile protected[sql] var triggerExecutor: TriggerExecutor = _ - private lazy val operatorStateMetadatas: Map[Long, OperatorStateMetadataLog] = { - populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) - } - - private def populateOperatorStateMetadatas(plan: SparkPlan): - Map[Long, OperatorStateMetadataLog] = { - plan.flatMap { - case s: StateStoreWriter => s.stateInfo.map { info => - val metadataPath = s.metadataFilePath() - info.operatorId -> new OperatorStateMetadataLog(sparkSession, - metadataPath.toString) - } - case _ => Seq.empty - }.toMap - } - protected def getTrigger(): TriggerExecutor = { assert(sources.nonEmpty, "sources should have been retrieved from the plan!") trigger match { @@ -922,7 +906,7 @@ class MicroBatchExecution( case s: StateStoreWriter => val metadata = s.operatorStateMetadata() val id = metadata.operatorInfo.operatorId - val metadataFile = operatorStateMetadatas(id) + val metadataFile = operatorStateMetadataLogs(id) if (!metadataFile.add(execCtx.batchId, metadata)) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 420deda3e0175..3a4bfb3b0edce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException import org.apache.spark.sql.internal.SQLConf @@ -239,6 +240,22 @@ abstract class StreamExecution( */ 
val commitLog = new CommitLog(sparkSession, checkpointFile("commits")) + lazy val operatorStateMetadataLogs: Map[Long, OperatorStateMetadataLog] = { + populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) + } + + private def populateOperatorStateMetadatas( + plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = { + plan.flatMap { + case s: StateStoreWriter => s.stateInfo.map { info => + val metadataPath = s.metadataFilePath() + info.operatorId -> new OperatorStateMetadataLog(sparkSession, + metadataPath.toString) + } + case _ => Seq.empty + }.toMap + } + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING From 67fe4bf32701f24f208dd1d2b1ab0945f416cfe7 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 6 Jun 2024 09:15:46 -0700 Subject: [PATCH 3/9] adding purging logic --- .../apache/spark/sql/execution/streaming/StreamExecution.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3a4bfb3b0edce..fce0bb44c88ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -694,6 +694,7 @@ abstract class StreamExecution( logDebug(s"Purging metadata at threshold=$threshold") offsetLog.purge(threshold) commitLog.purge(threshold) + operatorStateMetadataLogs.foreach(_._2.purge(threshold)) } } From d24677b3e7d49d45e858437c5893581394e75ef1 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 7 Jun 2024 10:17:08 -0700 Subject: [PATCH 4/9] it writes operatorStateMetadata out --- .../streaming/StateVariableInfo.scala | 71 +++++++++++++++++++ .../StatefulProcessorHandleImpl.scala | 11 ++- .../streaming/TransformWithStateExec.scala | 43 +++++++++++ .../streaming/statefulOperators.scala | 2 + .../streaming/TransformWithStateSuite.scala | 70 ++++++++++++++++++ 5 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableInfo.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableInfo.scala new file mode 100644 index 0000000000000..e381984b4672d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableInfo.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s.JsonAST.{JBool, JString} +import org.json4s.JsonDSL._ + +// Enum to store the types of state variables we support +sealed trait StateVariableType + +case object ValueState extends StateVariableType +case object ListState extends StateVariableType +case object MapState extends StateVariableType + +// This object is used to convert the state type from string to the corresponding enum +object StateVariableType { + def withName(name: String): StateVariableType = name match { + case "ValueState" => ValueState + case "ListState" => ListState + case "MapState" => MapState + case _ => throw new IllegalArgumentException(s"Unknown state type: $name") + } +} + +// This class is used to store the information about a state variable. +// It is stored in operatorProperties for the TransformWithStateExec operator +// to be able to validate that the State Variables are the same across restarts. +class StateVariableInfo( + val stateName: String, + val stateType: StateVariableType, + val isTtlEnabled: Boolean + ) { + def jsonValue: JValue = { + ("stateName" -> JString(stateName)) ~ + ("stateType" -> JString(stateType.toString)) ~ + ("isTtlEnabled" -> JBool(isTtlEnabled)) + } +} + +// This object is used to convert the state variable information +// from JSON to a list of StateVariableInfo +object StateVariableInfo { + implicit val formats: Formats = DefaultFormats + def fromJson(json: Any): List[StateVariableInfo] = { + assert(json.isInstanceOf[List[_]], s"Expected List but got ${json.getClass}") + val stateVariables = json.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + stateVariables.map { stateVariable => + new StateVariableInfo( + stateVariable("stateName").asInstanceOf[String], + StateVariableType.withName(stateVariable("stateType").asInstanceOf[String]), + stateVariable("isTtlEnabled").asInstanceOf[Boolean] + ) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index dcc77e94de280..fd7b7bacb4f7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -94,6 +94,9 @@ class StatefulProcessorHandleImpl( */ private[sql] val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private[sql] val stateVariables: util.List[StateVariableInfo] = + new util.ArrayList[StateVariableInfo]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" private def buildQueryInfo(): QueryInfo = { @@ -131,6 +134,7 @@ class StatefulProcessorHandleImpl( verifyStateVarOperations("get_value_state") incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) resultState } @@ -146,6 +150,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") ttlStates.add(valueStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) valueStateWithTTL } @@ -242,6 +247,7 @@ class StatefulProcessorHandleImpl( verifyStateVarOperations("get_list_state") incrementMetric("numListStateVars") 
val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ListState, false)) resultState } @@ -273,7 +279,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") ttlStates.add(listStateWithTTL) - + stateVariables.add(new StateVariableInfo(stateName, ListState, true)) listStateWithTTL } @@ -284,6 +290,7 @@ class StatefulProcessorHandleImpl( verifyStateVarOperations("get_map_state") incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, MapState, false)) resultState } @@ -300,7 +307,7 @@ class StatefulProcessorHandleImpl( valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) - + stateVariables.add(new StateVariableInfo(stateName, MapState, true)) mapStateWithTTL } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index fbd062acfb5dd..28aa08b963d7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -19,6 +19,14 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.jdk.CollectionConverters.CollectionHasAsScala + +import org.json4s.{DefaultFormats, JArray, JString} +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,8 +83,32 @@ case class TransformWithStateExec( initialState: SparkPlan) extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { + val operatorProperties: OperatorProperties = + OperatorProperties.create( + sparkContext, + "colFamilyMetadata" + ) + + override def operatorStateMetadataVersion: Int = 2 + override def shortName: String = "transformWithStateExec" + + /** Metadata of this stateful operator and its states stores. */ + override def operatorStateMetadata(): OperatorStateMetadata = { + val info = getStateInfo + val operatorInfo = OperatorInfoV1(info.operatorId, shortName) + val stateStoreInfo = + Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions)) + + val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ + ("outputMode" -> JString(outputMode.toString)) ~ + ("stateVariables" -> operatorProperties.value.get("stateVariables")) + + val json = compact(render(operatorPropertiesJson)) + OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) + } + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { if (timeMode == ProcessingTime) { // TODO: check if we can return true only if actual timers are registered, or there is @@ -306,6 +338,9 @@ case class TransformWithStateExec( store.abort() } } + operatorProperties.add(Map + ("stateVariables" -> JArray(processorHandle.stateVariables. 
+ asScala.map(_.jsonValue).toList))) setStoreMetrics(store) setOperatorMetrics() statefulProcessor.close() @@ -561,6 +596,14 @@ object TransformWithStateExec { initialStateDeserializer, initialState) } + + def deserializeOperatorProperties(json: String): Map[String, Any] = { + val parsedJson = JsonMethods.parse(json) + + implicit val formats = DefaultFormats + val deserializedMap: Map[String, Any] = parsedJson.extract[Map[String, Any]] + deserializedMap + } } // scalastyle:on argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 9d6b4cb55b02a..f7c6ffb8fdc47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -140,6 +140,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp */ def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + def operatorStateMetadataVersion: Int = 1 + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 0057af44d3e37..e283ba5c11f34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders} import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateStoreMultipleColumnFamiliesNotSupportedException} import org.apache.spark.sql.functions.timestamp_seconds @@ -448,6 +449,75 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("verify that operatorProperties contain all stateVariables") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { chkptDir => + val clock = new StreamManualClock + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") === 1) + }, + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + + AddData(inputData, "b"), + 
AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + + AddData(inputData, "b"), + AddData(inputData, "c"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("c", "1")), // should remove 'b' as count reaches 3 + + AddData(inputData, "d"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("c", "-1"), ("d", "1")), + StopStream + ) + + val df = spark.read + .format("state-metadata") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .load() + + val propsString = df.select("operatorProperties"). + collect().head.getString(0) + + val map = TransformWithStateExec. + deserializeOperatorProperties(propsString) + assert(map("timeMode") === "ProcessingTime") + assert(map("outputMode") === "Update") + + val stateVariableInfos = StateVariableInfo.fromJson( + map("stateVariables")) + assert(stateVariableInfos.size === 1) + val stateVariableInfo = stateVariableInfos.head + assert(stateVariableInfo.stateName === "countState") + assert(stateVariableInfo.isTtlEnabled === false) + assert(stateVariableInfo.stateType === ValueState) + } + } + } + test("transformWithState - streaming with rocksdb and processing time timer " + "and multiple timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> From 9f5b8f4ef6a6e046089d720da38691066c7ca929 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 7 Jun 2024 14:04:17 -0700 Subject: [PATCH 5/9] new constructor --- .../state/metadata/StateMetadataSource.scala | 21 ++++++++++++------- .../execution/streaming/HDFSMetadataLog.scala | 20 +++++++++++++----- .../streaming/IncrementalExecution.scala | 12 +++++------ .../streaming/OperatorStateMetadataLog.scala | 19 +++++++++++++++-- 4 files changed, 52 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 632d91bae6c50..44a55aa1d8d1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog} -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,6 +46,7 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, + operatorPropertiesJson: String, numColsPrefixKey: Int) { def toRow(): InternalRow = { new GenericInternalRow( @@ -55,6 +56,7 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, + UTF8String.fromString(operatorPropertiesJson), numColsPrefixKey )) } @@ -69,6 +71,7 @@ object StateMetadataTableEntry { .add("numPartitions", IntegerType) .add("minBatchId", LongType) 
.add("maxBatchId", LongType) + .add("operatorProperties", StringType) } } @@ -196,22 +199,26 @@ class StateMetadataPartitionReader( opIds.map { opId => val dirLocation = new Path(stateDir, opId.toString) val metadataFilePath = OperatorStateMetadata.metadataFilePath(dirLocation) - val log = new OperatorStateMetadataLog(SparkSession.active, metadataFilePath.toString) + val log = new OperatorStateMetadataLog(hadoopConf, metadataFilePath.toString) log.getLatest() }.filter(_.isDefined).map(_.get._2) } private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { allOperatorStateMetadata.flatMap { operatorStateMetadata => - require(operatorStateMetadata.version == 1) - val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] - operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, - operatorStateMetadataV1.operatorInfo.operatorName, + require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) + val operatorProperties = operatorStateMetadata match { + case _: OperatorStateMetadataV1 => "" + case v2: OperatorStateMetadataV2 => v2.operatorPropertiesJson + } + operatorStateMetadata.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(operatorStateMetadata.operatorInfo.operatorId, + operatorStateMetadata.operatorInfo.operatorName, stateStoreMetadata.storeName, stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, + operatorProperties, stateStoreMetadata.numColsPrefixKey ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 251cc16acdf43..c08659a98b7a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization @@ -48,9 +49,21 @@ import org.apache.spark.util.ArrayImplicits._ * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing * files in a directory always shows the latest files. 
*/ -class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: String) +class HDFSMetadataLog[T <: AnyRef : ClassTag]( + hadoopConf: Configuration, + path: String, + val metadataCacheEnabled: Boolean = false) extends MetadataLog[T] with Logging { + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + private implicit val formats: Formats = Serialization.formats(NoTypeHints) /** Needed to serialize type T into JSON when using Jackson */ @@ -64,15 +77,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val metadataPath = new Path(path) protected val fileManager = - CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf()) + CheckpointFileManager.create(metadataPath, hadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) } - protected val metadataCacheEnabled: Boolean - = sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_METADATA_CACHE_ENABLED) - /** * Cache the latest two batches. [[StreamExecution]] usually just accesses the latest two batches * when committing offsets, this cache will save some file system operations. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 55aef633a7365..42015a5bd29ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV1 +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -422,11 +422,11 @@ class IncrementalExecution( new Path(checkpointLocation).getParent.toString, new SerializableConfiguration(hadoopConf)) val opMetadataList = reader.allOperatorStateMetadata - ret = opMetadataList.map { operatorMetadata => - val metadataInfoV1 = operatorMetadata - .asInstanceOf[OperatorStateMetadataV1] - .operatorInfo - metadataInfoV1.operatorId -> metadataInfoV1.operatorName + ret = opMetadataList.map { + case OperatorStateMetadataV1(operatorInfo, _) => + operatorInfo.operatorId -> operatorInfo.operatorName + case OperatorStateMetadataV2(operatorInfo, _, _) => + operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala index 5b30f85938818..f77875279384f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -21,14 +21,29 @@ import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} +import org.apache.spark.sql.internal.SQLConf -class OperatorStateMetadataLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[OperatorStateMetadata](sparkSession, path) { +class OperatorStateMetadataLog( + hadoopConf: Configuration, + path: String, + metadataCacheEnabled: Boolean = false) + extends HDFSMetadataLog[OperatorStateMetadata](hadoopConf, path, metadataCacheEnabled) { + + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = { val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream] fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8)) From be7f2d0418a0dee6802138702c01f4caf90b8fd1 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 10 Jun 2024 09:53:57 -0700 Subject: [PATCH 6/9] only writing out metadatas for new runId --- .../streaming/MicroBatchExecution.scala | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index acfaeca10c5ae..57a321d29c0db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -902,14 +902,19 @@ class MicroBatchExecution( if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } - execCtx.executionPlan.executedPlan.collect { - case s: StateStoreWriter => - val metadata = s.operatorStateMetadata() - val id = metadata.operatorInfo.operatorId - val metadataFile = operatorStateMetadataLogs(id) - if (!metadataFile.add(execCtx.batchId, metadata)) { - throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) - } + val shouldWriteMetadatas = execCtx.previousContext.isEmpty || + execCtx.previousContext.get.executionPlan.runId != execCtx.executionPlan.runId + if (shouldWriteMetadatas) { + logError("writing out metadatas") + execCtx.executionPlan.executedPlan.collect { + case s: StateStoreWriter => + val metadata = s.operatorStateMetadata() + val id = metadata.operatorInfo.operatorId + val metadataFile = operatorStateMetadataLogs(id) + if (!metadataFile.add(execCtx.batchId, metadata)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) + } + } } } committedOffsets ++= execCtx.endOffsets From 375b9805d0f43cbf065946a7b7725f3e4adec2a6 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 10 Jun 2024 09:56:05 -0700 Subject: [PATCH 7/9] using match statement instead --- 
.../sql/execution/streaming/MicroBatchExecution.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 57a321d29c0db..6404681e5230a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -902,10 +902,13 @@ class MicroBatchExecution( if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } - val shouldWriteMetadatas = execCtx.previousContext.isEmpty || - execCtx.previousContext.get.executionPlan.runId != execCtx.executionPlan.runId + val shouldWriteMetadatas = execCtx.previousContext match { + case Some(prevCtx) + if prevCtx.executionPlan.runId == execCtx.executionPlan.runId => + false + case _ => true + } if (shouldWriteMetadatas) { - logError("writing out metadatas") execCtx.executionPlan.executedPlan.collect { case s: StateStoreWriter => val metadata = s.operatorStateMetadata() From c058c76c4b7fe39500b29de2714895edc412c6ce Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 10 Jun 2024 10:31:33 -0700 Subject: [PATCH 8/9] adding purge methods --- .../sql/execution/streaming/AsyncLogPurge.scala | 3 +++ .../sql/execution/streaming/HDFSMetadataLog.scala | 13 +++++++++++++ .../execution/streaming/MicroBatchExecution.scala | 1 + .../sql/execution/streaming/StreamExecution.scala | 6 +++++- 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala index aa393211a1c15..cd32520a70308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala @@ -42,6 +42,8 @@ trait AsyncLogPurge extends Logging { protected def purge(threshold: Long): Unit + protected def purgeOldest(): Unit + protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE) protected def purgeAsync(batchId: Long): Unit = { @@ -49,6 +51,7 @@ trait AsyncLogPurge extends Logging { asyncPurgeExecutorService.execute(() => { try { purge(batchId - minLogEntriesToMaintain) + purgeOldest() } catch { case throwable: Throwable => logError("Encountered error while performing async log purge", throwable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index c08659a98b7a2..bb59732db5602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -335,6 +335,19 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag]( } } + def purgeOldest(minEntriesToMaintain: Int): Unit = { + val batchIds = listBatches.sorted + if (batchIds.length > minEntriesToMaintain) { + val filesToDelete = batchIds.take(batchIds.length - minEntriesToMaintain) + filesToDelete.foreach { batchId => + val path = batchIdToPath(batchId) + fileManager.delete(path) + if (metadataCacheEnabled) batchCache.remove(batchId) + 
logTrace(s"Removed metadata log file: $path") + } + } + } + /** List the available batches on file system. */ protected def listBatches: Array[Long] = { val batchIds = fileManager.list(metadataPath, batchFilesFilter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6404681e5230a..f85adf8c34363 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -663,6 +663,7 @@ class MicroBatchExecution( purgeAsync(execCtx.batchId) } else { purge(execCtx.batchId - minLogEntriesToMaintain) + purgeOldest() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index fce0bb44c88ae..605f536122f92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -694,7 +694,11 @@ abstract class StreamExecution( logDebug(s"Purging metadata at threshold=$threshold") offsetLog.purge(threshold) commitLog.purge(threshold) - operatorStateMetadataLogs.foreach(_._2.purge(threshold)) + } + + protected def purgeOldest(): Unit = { + operatorStateMetadataLogs.foreach( + _._2.purgeOldest(minLogEntriesToMaintain)) } } From f0ccc6443335f3d06e0f93ab56341f010dbb361b Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 13 Jun 2024 14:47:22 -0700 Subject: [PATCH 9/9] removing accumulator --- .../StatefulProcessorHandleImpl.scala | 124 +++++++++++------- .../streaming/TransformWithStateExec.scala | 29 ++-- .../state/OperatorStateMetadata.scala | 41 ------ .../streaming/state/ListStateSuite.scala | 14 +- .../streaming/state/MapStateSuite.scala | 12 +- .../state/StatefulProcessorHandleSuite.scala | 20 +-- .../streaming/state/ValueStateSuite.scala | 22 ++-- 7 files changed, 126 insertions(+), 136 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index fd7b7bacb4f7b..e1d578fb2e5ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -48,7 +48,7 @@ object ImplicitGroupingKeyTracker { */ object StatefulProcessorHandleState extends Enumeration { type StatefulProcessorHandleState = Value - val CREATED, INITIALIZED, DATA_PROCESSED, TIMER_PROCESSED, CLOSED = Value + val CREATED, PRE_INIT, INITIALIZED, DATA_PROCESSED, TIMER_PROCESSED, CLOSED = Value } class QueryInfoImpl( @@ -70,7 +70,8 @@ class QueryInfoImpl( /** * Class that provides a concrete implementation of a StatefulProcessorHandle. Note that we keep * track of valid transitions as various functions are invoked to track object lifecycle. - * @param store - instance of state store + * @param store - instance of state store - if this processorHandle is being created + * on the driver, the Store will be None. 
* @param runId - unique id for the current run * @param keyEncoder - encoder for the key * @param isStreaming - defines whether the query is streaming or batch @@ -78,7 +79,7 @@ class QueryInfoImpl( * @param metrics - metrics to be updated as part of stateful processing */ class StatefulProcessorHandleImpl( - store: StateStore, + store: Option[StateStore], runId: UUID, keyEncoder: ExpressionEncoder[Any], timeMode: TimeMode, @@ -131,34 +132,42 @@ class StatefulProcessorHandleImpl( override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verifyStateVarOperations("get_value_state") - incrementMetric("numValueStateVars") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) - stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) - resultState + store match { + case Some(store) => + verifyStateVarOperations("get_value_state") + incrementMetric("numValueStateVars") + new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + case None => + stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) + null + } } override def getValueState[T]( stateName: String, valEncoder: Encoder[T], ttlConfig: TTLConfig): ValueState[T] = { - verifyStateVarOperations("get_value_state") - validateTTLConfig(ttlConfig, stateName) - - assert(batchTimestampMs.isDefined) - val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numValueStateWithTTLVars") - ttlStates.add(valueStateWithTTL) - stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) - valueStateWithTTL + store match { + case Some(store) => verifyStateVarOperations("get_value_state") + validateTTLConfig(ttlConfig, stateName) + assert(batchTimestampMs.isDefined) + val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) + incrementMetric("numValueStateWithTTLVars") + ttlStates.add(valueStateWithTTL) + valueStateWithTTL + case None => + stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) + null + } } override def getQueryInfo(): QueryInfo = currQueryInfo - private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder) + private lazy val timerState = new TimerStateImpl(store.get, timeMode, keyEncoder) private def verifyStateVarOperations(operationType: String): Unit = { + assert(store.isDefined, "Cannot call this method on a handle without a state store") if (currState != CREATED) { throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, currState.toString) @@ -166,6 +175,7 @@ class StatefulProcessorHandleImpl( } private def verifyTimerOperations(operationType: String): Unit = { + assert(store.isDefined, "Cannot call this method on a handle without a state store") if (timeMode == NoTime) { throw StateStoreErrors.cannotPerformOperationWithInvalidTimeMode(operationType, timeMode.toString) @@ -238,17 +248,21 @@ class StatefulProcessorHandleImpl( */ override def deleteIfExists(stateName: String): Unit = { verifyStateVarOperations("delete_if_exists") - if (store.removeColFamilyIfExists(stateName)) { + if (store.get.removeColFamilyIfExists(stateName)) { incrementMetric("numDeletedStateVars") } } override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verifyStateVarOperations("get_list_state") - incrementMetric("numListStateVars") - val resultState = new ListStateImpl[T](store, stateName, 
keyEncoder, valEncoder)
-    stateVariables.add(new StateVariableInfo(stateName, ListState, false))
-    resultState
+    store match {
+      case Some(store) =>
+        verifyStateVarOperations("get_list_state")
+        incrementMetric("numListStateVars")
+        new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
+      case None =>
+        stateVariables.add(new StateVariableInfo(stateName, ListState, false))
+        null
+    }
   }
 
   /**
@@ -271,27 +285,34 @@ class StatefulProcessorHandleImpl(
       valEncoder: Encoder[T],
       ttlConfig: TTLConfig): ListState[T] = {
-    verifyStateVarOperations("get_list_state")
-    validateTTLConfig(ttlConfig, stateName)
-
-    assert(batchTimestampMs.isDefined)
-    val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
-      keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
-    incrementMetric("numListStateWithTTLVars")
-    ttlStates.add(listStateWithTTL)
-    stateVariables.add(new StateVariableInfo(stateName, ListState, true))
-    listStateWithTTL
+    store match {
+      case Some(store) => verifyStateVarOperations("get_list_state")
+        validateTTLConfig(ttlConfig, stateName)
+        assert(batchTimestampMs.isDefined)
+        val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
+          keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
+        incrementMetric("numListStateWithTTLVars")
+        ttlStates.add(listStateWithTTL)
+        listStateWithTTL
+      case None =>
+        stateVariables.add(new StateVariableInfo(stateName, ListState, true))
+        null
+    }
   }
 
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
       valEncoder: Encoder[V]): MapState[K, V] = {
-    verifyStateVarOperations("get_map_state")
-    incrementMetric("numMapStateVars")
-    val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
-    stateVariables.add(new StateVariableInfo(stateName, MapState, false))
-    resultState
+    store match {
+      case Some(store) =>
+        verifyStateVarOperations("get_map_state")
+        incrementMetric("numMapStateVars")
+        new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
+      case None =>
+        stateVariables.add(new StateVariableInfo(stateName, MapState, false))
+        null
+    }
   }
 
   override def getMapState[K, V](
@@ -299,16 +320,19 @@ class StatefulProcessorHandleImpl(
       userKeyEnc: Encoder[K],
       valEncoder: Encoder[V],
       ttlConfig: TTLConfig): MapState[K, V] = {
-    verifyStateVarOperations("get_map_state")
-    validateTTLConfig(ttlConfig, stateName)
-
-    assert(batchTimestampMs.isDefined)
-    val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
-      valEncoder, ttlConfig, batchTimestampMs.get)
-    incrementMetric("numMapStateWithTTLVars")
-    ttlStates.add(mapStateWithTTL)
-    stateVariables.add(new StateVariableInfo(stateName, MapState, true))
-    mapStateWithTTL
+    store match {
+      case Some(store) => verifyStateVarOperations("get_map_state")
+        validateTTLConfig(ttlConfig, stateName)
+        assert(batchTimestampMs.isDefined)
+        val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName,
+          keyEncoder, userKeyEnc, valEncoder, ttlConfig, batchTimestampMs.get)
+        incrementMetric("numMapStateWithTTLVars")
+        ttlStates.add(mapStateWithTTL)
+        mapStateWithTTL
+      case None =>
+        stateVariables.add(new StateVariableInfo(stateName, MapState, true))
+        null
+    }
   }
 
   private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 28aa08b963d7d..d085581173559 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.util import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS @@ -83,11 +84,8 @@ case class TransformWithStateExec( initialState: SparkPlan) extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { - val operatorProperties: OperatorProperties = - OperatorProperties.create( - sparkContext, - "colFamilyMetadata" - ) + val operatorProperties: util.Map[String, JValue] = + new util.HashMap[String, JValue]() override def operatorStateMetadataVersion: Int = 2 @@ -103,7 +101,7 @@ case class TransformWithStateExec( val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ ("outputMode" -> JString(outputMode.toString)) ~ - ("stateVariables" -> operatorProperties.value.get("stateVariables")) + ("stateVariables" -> operatorProperties.get("stateVariables")) val json = compact(render(operatorPropertiesJson)) OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) @@ -338,9 +336,6 @@ case class TransformWithStateExec( store.abort() } } - operatorProperties.add(Map - ("stateVariables" -> JArray(processorHandle.stateVariables. - asScala.map(_.jsonValue).toList))) setStoreMetrics(store) setOperatorMetrics() statefulProcessor.close() @@ -378,6 +373,18 @@ case class TransformWithStateExec( validateTimeMode() + val driverProcessorHandle = new StatefulProcessorHandleImpl( + None, getStateInfo.queryRunId, keyEncoder, timeMode, + isStreaming, batchTimestampMs, metrics) + + driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) + statefulProcessor.setHandle(driverProcessorHandle) + statefulProcessor.init(outputMode, timeMode) + operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables. 
+ asScala.map(_.jsonValue).toList)) + statefulProcessor.setHandle(null) + driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + if (hasInitialState) { val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) val hadoopConfBroadcast = sparkContext.broadcast( @@ -491,7 +498,7 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeMode, + Some(store), getStateInfo.queryRunId, keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) @@ -505,7 +512,7 @@ case class TransformWithStateExec( childDataIterator: Iterator[InternalRow], initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { - val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, + val processorHandle = new StatefulProcessorHandleImpl(Some(store), getStateInfo.queryRunId, keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 1d016522309a1..36bfb34edc412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -74,47 +74,6 @@ case class OperatorStateMetadataV1( override def version: Int = 1 } -/** - * Accumulator to store arbitrary Operator properties. - * This accumulator is used to store the properties of an operator that are not - * available on the driver at the time of planning, and will only be known from - * the executor side. - */ -class OperatorProperties(initValue: Map[String, JValue] = Map.empty) - extends AccumulatorV2[Map[String, JValue], Map[String, JValue]] { - - private var _value: Map[String, JValue] = initValue - - override def isZero: Boolean = _value.isEmpty - - override def copy(): AccumulatorV2[Map[String, JValue], Map[String, JValue]] = { - val newAcc = new OperatorProperties - newAcc._value = _value - newAcc - } - - override def reset(): Unit = _value = Map.empty[String, JValue] - - override def add(v: Map[String, JValue]): Unit = _value ++= v - - override def merge(other: AccumulatorV2[Map[String, JValue], Map[String, JValue]]): Unit = { - _value ++= other.value - } - - override def value: Map[String, JValue] = _value -} - -object OperatorProperties { - def create( - sc: SparkContext, - name: String, - initValue: Map[String, JValue] = Map.empty): OperatorProperties = { - val acc = new OperatorProperties(initValue) - acc.register(sc, name = Some(name)) - acc - } -} - // operatorProperties is an arbitrary JSON formatted string that contains // any properties that we would want to store for a particular operator. 
case class OperatorStateMetadataV2( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index 1e6136fd38a37..66df9f85cc2f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -37,7 +37,7 @@ class ListStateSuite extends StateVariableSuiteBase { private def testMapStateWithNullUserKey()(runListOps: ListState[Long] => Unit): Unit = { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -70,7 +70,7 @@ class ListStateSuite extends StateVariableSuiteBase { test("List state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) @@ -98,7 +98,7 @@ class ListStateSuite extends StateVariableSuiteBase { test("List state operations for multiple instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) @@ -136,7 +136,7 @@ class ListStateSuite extends StateVariableSuiteBase { test("List state operations with list, value, another list instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) @@ -166,7 +166,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) @@ -186,7 +186,7 @@ class ListStateSuite extends StateVariableSuiteBase { assert(ttlStateValueIterator.hasNext) // increment batchProcessingTime, or watermark and ensure expired value is not returned - val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val nextBatchHandle = new 
StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) @@ -222,7 +222,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index 5b304c55dd5a7..809b28bf049c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -40,7 +40,7 @@ class MapStateSuite extends StateVariableSuiteBase { test("Map state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: MapState[String, Double] = @@ -74,7 +74,7 @@ class MapStateSuite extends StateVariableSuiteBase { test("Map state operations for multiple map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState1: MapState[Long, Double] = @@ -113,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase { test("Map state operations with list, value, another map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val mapTestState1: MapState[String, Int] = @@ -174,7 +174,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) @@ -195,7 +195,7 @@ class MapStateSuite extends StateVariableSuiteBase { assert(ttlStateValueIterator.hasNext) // increment batchProcessingTime, or watermark and ensure expired value is not returned - val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val nextBatchHandle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) @@ -232,7 +232,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 52bdb0213c7e5..7d94994b7cce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -49,7 +49,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test(s"value state creation with timeMode=$timeMode should succeed") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -90,7 +90,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { "and invalid state should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) Seq(StatefulProcessorHandleState.INITIALIZED, @@ -108,7 +108,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test("registering processing/event time timeouts with None timeMode should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, TimeMode.None()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) @@ -144,7 +144,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test(s"registering timeouts with timeMode=$timeMode should succeed") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -165,7 +165,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test(s"verify listing of registered timers with timeMode=$timeMode") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new 
StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -205,7 +205,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test(s"registering timeouts with timeMode=$timeMode and invalid state should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) Seq(StatefulProcessorHandleState.CREATED, @@ -222,7 +222,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test("ttl States are populated for valueState and timeMode=ProcessingTime") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(10)) @@ -240,7 +240,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test("ttl States are populated for listState and timeMode=ProcessingTime") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(10)) @@ -258,7 +258,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test("ttl States are populated for mapState and timeMode=ProcessingTime") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(10)) @@ -276,7 +276,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { test("ttl States are not populated for timeMode=None") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), keyExprEncoder, TimeMode.None()) handle.getValueState("testValueState", Encoders.STRING) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index e5875da947a37..138a9e4097927 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -48,7 +48,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("Implicit key operations") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) 
val stateName = "testState" @@ -92,7 +92,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("Value state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -118,7 +118,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("Value state operations for multiple instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState1: ValueState[Long] = handle.getValueState[Long]( @@ -163,7 +163,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("Value state operations for unsupported type name should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val cfName = "_testState" @@ -203,7 +203,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("test SQL encoder - Value state operations for Primitive(Double) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", @@ -229,7 +229,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("test SQL encoder - Value state operations for Primitive(Long) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", @@ -255,7 +255,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("test SQL encoder - Value state operations for case class instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", @@ -281,7 +281,7 @@ class ValueStateSuite extends StateVariableSuiteBase { test("test SQL encoder - Value state operations for POJO instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = 
provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", @@ -309,7 +309,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) @@ -329,7 +329,7 @@ class ValueStateSuite extends StateVariableSuiteBase { assert(ttlStateValueIterator.hasNext) // increment batchProcessingTime, or watermark and ensure expired value is not returned - val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val nextBatchHandle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) @@ -365,7 +365,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + val handle = new StatefulProcessorHandleImpl(Some(store), UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))