From 3554af22d3a1eb0b96702c26a91c691b3c9e21f2 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 13 Jun 2024 10:38:05 -0700 Subject: [PATCH 1/9] trying to plumb schema through planning rule --- .../streaming/IncrementalExecution.scala | 25 ++- .../execution/streaming/ListStateImpl.scala | 7 +- .../streaming/ListStateImplWithTTL.scala | 8 +- .../execution/streaming/MapStateImpl.scala | 19 +-- .../streaming/MapStateImplWithTTL.scala | 8 +- .../streaming/MicroBatchExecution.scala | 14 +- .../streaming/StateSchemaV3File.scala | 78 ++++++++++ .../StatefulProcessorHandleImpl.scala | 13 ++ .../execution/streaming/StreamExecution.scala | 16 ++ .../streaming/TransformWithStateExec.scala | 21 ++- .../execution/streaming/ValueStateImpl.scala | 18 ++- .../streaming/ValueStateImplWithTTL.scala | 8 +- .../state/HDFSBackedStateStoreProvider.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 10 ++ .../streaming/state/SchemaHelper.scala | 142 +++++++++++++++++- .../streaming/state/StateStore.scala | 50 +++++- .../streaming/statefulOperators.scala | 15 ++ .../streaming/state/MemoryStateStore.scala | 4 + .../streaming/TransformWithStateSuite.scala | 77 ++++++++++ 19 files changed, 501 insertions(+), 36 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 42015a5bd29ee..617da35759da7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -187,6 +187,23 @@ class IncrementalExecution( } } + object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging { + logError(s"### PopulateSchemaV3Rule, batchId = $currentBatchId") + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + case tws: TransformWithStateExec => + val stateSchemaV3File = new StateSchemaV3File( + hadoopConf, tws.stateSchemaFilePath().toString) + logError(s"### trying to get schema from file: ${tws.stateSchemaFilePath()}") + stateSchemaV3File.getLatest() match { + case Some((_, schemaJValue)) => + logError("### PASSING SCHEMA TO OPERATOR") + logError(s"### schemaJValue: $schemaJValue") + tws.copy(columnFamilyJValue = Some(schemaJValue)) + case None => tws + } + } + } + object StateOpIdRule extends SparkPlanPartialRule { override val rule: PartialFunction[SparkPlan, SparkPlan] = { case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, @@ -454,16 +471,18 @@ class IncrementalExecution( } override def apply(plan: SparkPlan): SparkPlan = { + logError(s"### applying rules to plan") val planWithStateOpId = plan transform composedRule + val planWithSchema = planWithStateOpId transform PopulateSchemaV3Rule.rule // Need to check before write to metadata because we need to detect add operator // Only check when streaming is restarting and is first batch if (isFirstBatch && currentBatchId != 0) { - checkOperatorValidWithMetadata(planWithStateOpId) + checkOperatorValidWithMetadata(planWithSchema) } // The rule doesn't change the plan but cause the side effect that metadata is written // in the checkpoint directory of stateful operator. 
- simulateWatermarkPropagation(planWithStateOpId) - planWithStateOpId transform WatermarkPropagationRule.rule + simulateWatermarkPropagation(planWithSchema) + planWithSchema transform WatermarkPropagationRule.rule } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 56c9d2664d9e2..b74afd6f418db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState /** @@ -44,8 +44,9 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) + store.createColFamilyIfAbsent(columnFamilySchema) /** Whether state exists or not. */ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index dc72f8bcd5600..b5b902ab98245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} import org.apache.spark.util.NextIterator @@ -52,11 +52,13 @@ class ListStateImplWithTTL[S]( private lazy val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index c58f32ed756db..b9558c1b6c310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState -import org.apache.spark.sql.types.{BinaryType, StructType} class MapStateImpl[K, V]( store: StateStore, @@ -30,18 +30,15 @@ class MapStateImpl[K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V]) extends MapState[K, V] with Logging { - // Pack grouping key and user key together as a prefixed composite key - private val schemaForCompositeKeyRow: StructType = - new StructType() - .add("key", BinaryType) - .add("userKey", BinaryType) - private val schemaForValueRow: StructType = new StructType().add("value", BinaryType) private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = new CompositeKeyStateEncoder( - keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName) + keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow, - PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1)) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) + + store.createColFamilyIfAbsent(columnFamilySchema) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 2ab06f36dd5f7..ef54624b9ecb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} import org.apache.spark.util.NextIterator @@ -55,11 +55,13 @@ class MapStateImplWithTTL[K, V]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Whether state exists or not. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index f85adf8c34363..70b5909ca5fe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -906,10 +906,22 @@ class MicroBatchExecution( val shouldWriteMetadatas = execCtx.previousContext match { case Some(prevCtx) if prevCtx.executionPlan.runId == execCtx.executionPlan.runId => - false + false case _ => true } + if (shouldWriteMetadatas) { + execCtx.executionPlan.executedPlan.collect { + case tws: TransformWithStateExec => + val schema = tws.getColumnFamilyJValue() + val metadata = tws.operatorStateMetadata() + val id = metadata.operatorInfo.operatorId + val schemaFile = stateSchemaLogs(id) + logError(s"Writing schema for operator $id at path ${schemaFile.metadataPath}") + if (!schemaFile.add(execCtx.batchId, schema)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) + } + } execCtx.executionPlan.executedPlan.collect { case s: StateStoreWriter => val metadata = s.operatorStateMetadata() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala new file mode 100644 index 0000000000000..426b03ed3c424 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, OutputStream, StringReader} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} +import org.json4s.JValue +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +class StateSchemaV3File( + hadoopConf: Configuration, + path: String, + metadataCacheEnabled: Boolean = false) + extends HDFSMetadataLog[JValue](hadoopConf, path, metadataCacheEnabled) { + + final val MAX_UTF_CHUNK_SIZE = 65535 + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + + override protected def serialize(schema: JValue, out: OutputStream): Unit = { + val json = compact(render(schema)) + val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + + val outputStream = out.asInstanceOf[FSDataOutputStream] + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // Each metadata consists of multiple chunks in schema version 3. 
+ try { + val numMetadataChunks = (json.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(json) + outputStream.writeInt(numMetadataChunks) + (0 until numMetadataChunks).foreach { _ => + val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + outputStream.close() + } catch { + case e: Throwable => + throw e + } + } + + override protected def deserialize(in: InputStream): JValue = { + val buf = new StringBuilder + val inputStream = in.asInstanceOf[FSDataInputStream] + val numKeyChunks = inputStream.readInt() + (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) + val json = buf.toString() + JsonMethods.parse(json) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index e1d578fb2e5ca..1ba3cc5a981d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -98,6 +98,9 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() + private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] = + new util.ArrayList[ColumnFamilySchema]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" private def buildQueryInfo(): QueryInfo = { @@ -139,6 +142,8 @@ class StatefulProcessorHandleImpl( new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) + val colFamilySchema = resultState.columnFamilySchema + columnFamilySchemas.add(colFamilySchema) null } } @@ -158,6 +163,8 @@ class StatefulProcessorHandleImpl( valueStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) + val colFamilySchema = resultState.columnFamilySchema + columnFamilySchemas.add(colFamilySchema) null } } @@ -296,6 +303,8 @@ class StatefulProcessorHandleImpl( listStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, ListState, true)) + val colFamilySchema = resultState.columnFamilySchema + columnFamilySchemas.add(colFamilySchema) null } } @@ -311,6 +320,8 @@ class StatefulProcessorHandleImpl( new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) + val colFamilySchema = resultState.columnFamilySchema + columnFamilySchemas.add(colFamilySchema) null } } @@ -331,6 +342,8 @@ class StatefulProcessorHandleImpl( mapStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, MapState, true)) + val colFamilySchema = resultState.columnFamilySchema + columnFamilySchemas.add(colFamilySchema) null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 605f536122f92..c217e5b7f9978 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -244,6 +244,10 @@ abstract class StreamExecution( 
populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) } + lazy val stateSchemaLogs: Map[Long, StateSchemaV3File] = { + populateStateSchemaFiles(getLatestExecutionContext().executionPlan.executedPlan) + } + private def populateOperatorStateMetadatas( plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = { plan.flatMap { @@ -256,6 +260,18 @@ abstract class StreamExecution( }.toMap } + private def populateStateSchemaFiles( + plan: SparkPlan): Map[Long, StateSchemaV3File] = { + plan.flatMap { + case s: StateStoreWriter => s.stateInfo.map { info => + val schemaFilePath = s.stateSchemaFilePath() + info.operatorId -> new StateSchemaV3File(sparkSession, + schemaFilePath.toString) + } + case _ => Seq.empty + }.toMap + } + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index d085581173559..764b583705a2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -81,7 +81,8 @@ case class TransformWithStateExec( initialStateGroupingAttrs: Seq[Attribute], initialStateDataAttrs: Seq[Attribute], initialStateDeserializer: Expression, - initialState: SparkPlan) + initialState: SparkPlan, + columnFamilyJValue: Option[JValue] = None) extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { val operatorProperties: util.Map[String, JValue] = @@ -91,6 +92,7 @@ case class TransformWithStateExec( override def shortName: String = "transformWithStateExec" + columnFamilySchemas() /** Metadata of this stateful operator and its states stores. */ override def operatorStateMetadata(): OperatorStateMetadata = { @@ -107,6 +109,20 @@ case class TransformWithStateExec( OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } + def getColumnFamilyJValue(): JValue = { + val columnFamilySchemas = operatorProperties.get("columnFamilySchemas") + columnFamilySchemas + } + + def columnFamilySchemas(): List[ColumnFamilySchema] = { + val columnFamilySchemas = ColumnFamilySchemaV1.fromJValue(columnFamilyJValue) + columnFamilySchemas.foreach { + case c1: ColumnFamilySchemaV1 => logError(s"### colFamilyName:" + + s"${c1.columnFamilyName}") + } + columnFamilySchemas + } + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { if (timeMode == ProcessingTime) { // TODO: check if we can return true only if actual timers are registered, or there is @@ -382,6 +398,9 @@ case class TransformWithStateExec( statefulProcessor.init(outputMode, timeMode) operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables. asScala.map(_.jsonValue).toList)) + operatorProperties.put("columnFamilySchemas", JArray(driverProcessorHandle. 
+ columnFamilySchemas.asScala.map(_.jsonValue).toList)) + statefulProcessor.setHandle(null) driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index d916011245c00..50816822e72e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState /** @@ -32,6 +32,17 @@ import org.apache.spark.sql.streaming.ValueState * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored */ +object ValueStateImpl { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false) + } +} + class ValueStateImpl[S]( store: StateStore, stateName: String, @@ -42,11 +53,12 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Function to check if state exists. 
Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 0ed5a6f29a984..cf98697878574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} /** @@ -49,11 +49,13 @@ class ValueStateImplWithTTL[S]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilySchema = new ColumnFamilySchemaV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 543cd74c489d0..2ba710d1b05cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -130,6 +130,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) } + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) + } + // Multiple col families are not supported with HDFSBackedStateStoreProvider. Throw an exception // if the user tries to use a non-default col family. 
private def assertUseOfDefaultColFamily(colFamilyName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e7fc9f56dd9eb..c52e69e16d581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -65,6 +65,16 @@ private[sql] class RocksDBStateStoreProvider RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) } + override def createColFamilyIfAbsent( + colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + createColFamilyIfAbsent( + colFamilyMetadata.columnFamilyName, + colFamilyMetadata.keySchema, + colFamilyMetadata.valueSchema, + colFamilyMetadata.keyStateEncoderSpec, + colFamilyMetadata.multipleValuesPerKey) + } + override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { verify(key != null, "Key cannot be null") val kvEncoder = keyValueEncoderMap.get(colFamilyName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 2eef3d9fc22ed..71776005ce24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -17,14 +17,71 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.StringReader +import java.io.{OutputStream, StringReader} -import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} +import org.json4s.{DefaultFormats, JsonAST} +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.sql.execution.streaming.MetadataVersionUtil +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +sealed trait ColumnFamilySchema extends Serializable { + def jsonValue: JsonAST.JObject + + def json: String +} + +case class ColumnFamilySchemaV1( + val columnFamilyName: String, + val keySchema: StructType, + val valueSchema: StructType, + val keyStateEncoderSpec: KeyStateEncoderSpec, + val multipleValuesPerKey: Boolean) extends ColumnFamilySchema { + def jsonValue: JsonAST.JObject = { + ("columnFamilyName" -> JString(columnFamilyName)) ~ + ("keySchema" -> keySchema.json) ~ + ("valueSchema" -> valueSchema.json) ~ + ("keyStateEncoderSpec" -> keyStateEncoderSpec.jsonValue) ~ + ("multipleValuesPerKey" -> JBool(multipleValuesPerKey)) + } + + def json: String = { + compact(render(jsonValue)) + } +} + +object ColumnFamilySchemaV1 { + def fromJson(json: List[Map[String, Any]]): List[ColumnFamilySchema] = { + assert(json.isInstanceOf[List[_]]) + + json.map { colFamilyMap => + new ColumnFamilySchemaV1( + colFamilyMap("columnFamilyName").asInstanceOf[String], + StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]), + StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]), + KeyStateEncoderSpec.fromJson(colFamilyMap("keyStateEncoderSpec") + .asInstanceOf[Map[String, 
Any]]), + colFamilyMap("multipleValuesPerKey").asInstanceOf[Boolean] + ) + } + } + + def fromJValue(jValue: JValue): List[ColumnFamilySchema] = { + implicit val formats: DefaultFormats.type = DefaultFormats + val deserializedList: List[Any] = jValue.extract[List[Any]] + assert(deserializedList.isInstanceOf[List[_]], + s"Expected List but got ${deserializedList.getClass}") + val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas) + } +} + /** * Helper classes for reading/writing state schema. */ @@ -68,6 +125,34 @@ object SchemaHelper { } } + class SchemaV3Reader( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { + + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + def read: List[ColumnFamilySchema] = { + if (!fm.exists(schemaFilePath)) { + return List.empty + } + val buf = new StringBuilder + val inputStream = fm.open(schemaFilePath) + val numKeyChunks = inputStream.readInt() + (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) + val json = buf.toString() + val parsedJson = JsonMethods.parse(json) + + implicit val formats = DefaultFormats + val deserializedList: List[Any] = parsedJson.extract[List[Any]] + assert(deserializedList.isInstanceOf[List[_]], + s"Expected List but got ${deserializedList.getClass}") + val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas) + } + } + trait SchemaWriter { val version: Int @@ -144,4 +229,55 @@ object SchemaHelper { } } } + + object SchemaV3Writer { + def getSchemaFilePath(stateCheckpointPath: Path): Path = { + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + + def serialize(out: OutputStream, schema: List[ColumnFamilySchema]): Unit = { + val json = schema.map(_.json) + out.write(compact(render(json)).getBytes("UTF-8")) + } + } + /** + * Schema writer for schema version 3. Because this writer writes out ColFamilyMetadatas + * instead of key and value schemas, it is not compatible with the SchemaWriter interface. + */ + class SchemaV3Writer( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { + val version: Int = 3 + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + + // 2^16 - 1 bytes + final val MAX_UTF_CHUNK_SIZE = 65535 + def writeSchema(metadatasJson: String): Unit = { + val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + + if (fm.exists(schemaFilePath)) return + + fm.mkdirs(schemaFilePath.getParent) + val outputStream = fm.createAtomic(schemaFilePath, overwriteIfPossible = false) + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // Each metadata consists of multiple chunks in schema version 3. 
+ try { + val numMetadataChunks = (metadatasJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(metadatasJson) + outputStream.writeInt(numMetadataChunks) + (0 until numMetadataChunks).foreach { _ => + val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + outputStream.close() + } catch { + case e: Throwable => + outputStream.cancel() + throw e + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 8c2170abe3116..7e0c9e4868899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -28,6 +28,10 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.{JInt, JsonAST, JString} +import org.json4s.JsonAST.JObject +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{SparkContext, SparkEnv, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, LogKeys, MDC} @@ -133,6 +137,10 @@ trait StateStore extends ReadStateStore { useMultipleValuesPerKey: Boolean = false, isInternal: Boolean = false): Unit + def createColFamilyIfAbsent( + colFamilyMetadata: ColumnFamilySchemaV1 + ): Unit + /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows * in the params can be reused, and must make copies of the data as needed for persistence. @@ -289,9 +297,35 @@ class InvalidUnsafeRowException(error: String) "among restart. For the first case, you can try to restart the application without " + s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null) -sealed trait KeyStateEncoderSpec +sealed trait KeyStateEncoderSpec { + def jsonValue: JsonAST.JObject + def json: String = compact(render(jsonValue)) +} -case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec +object KeyStateEncoderSpec { + def fromJson(m: Map[String, Any]): KeyStateEncoderSpec = { + // match on type + val keySchema = StructType.fromString(m("keySchema").asInstanceOf[String]) + m("keyStateEncoderType").asInstanceOf[String] match { + case "NoPrefixKeyStateEncoderSpec" => + NoPrefixKeyStateEncoderSpec(keySchema) + case "RangeKeyScanStateEncoderSpec" => + val orderingOrdinals = m("orderingOrdinals"). 
+ asInstanceOf[List[_]].map(_.asInstanceOf[Int]) + RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) + case "PrefixKeyScanStateEncoderSpec" => + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int] + PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + } + } +} + +case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) + } +} case class PrefixKeyScanStateEncoderSpec( keySchema: StructType, @@ -299,6 +333,12 @@ case class PrefixKeyScanStateEncoderSpec( if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString) } + + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) ~ + ("numColsPrefixKey" -> JInt(numColsPrefixKey)) + } } /** Encodes rows so that they can be range-scanned based on orderingOrdinals */ @@ -308,6 +348,12 @@ case class RangeKeyScanStateEncoderSpec( if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString) } + + override def jsonValue: JObject = { + ("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) ~ + ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index f7c6ffb8fdc47..b6c8a9970021a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -78,6 +78,21 @@ trait StatefulOperator extends SparkPlan { new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") } + + // /state//0//_metadata/schema + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 6a476635a6dbe..fa678e4a4f78f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -40,6 +40,10 @@ class MemoryStateStore extends StateStore() { throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider") } + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") + } + override def 
removeColFamilyIfExists(colFamilyName: String): Boolean = { throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index e283ba5c11f34..986006bc5973e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.time.Duration import java.util.UUID import org.apache.spark.SparkRuntimeException @@ -35,6 +36,36 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorWithTTL(ttlConfig: TTLConfig) + extends StatefulProcessor[String, String, (String, String)] + with Logging { + + @transient private var _countState: ValueStateImplWithTTL[Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle + .getValueState("countState", Encoders.scalaLong, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[Long]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -854,6 +885,52 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + test("transformWithState - verify that query with ttl enabled after restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { chkptDir => + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + StopStream + ) + + logError(s"### Restarting stream") + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithTTL(TTLConfig(Duration.ofMinutes(1))), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + // verify that query with ttl enabled after restart fails + testStream(result, OutputMode.Append())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000) + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest { From 2adbb02d5f2d924854059175b81bc4129ef902df Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 13 Jun 2024 11:07:54 -0700 Subject: [PATCH 2/9] rule is being applied --- .../streaming/IncrementalExecution.scala | 7 +-- .../streaming/MicroBatchExecution.scala | 57 
+++++++++---------- .../streaming/TransformWithStateExec.scala | 7 +-- .../streaming/TransformWithStateSuite.scala | 6 +- 4 files changed, 34 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 617da35759da7..f7bc2fb5f27ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -188,16 +188,12 @@ class IncrementalExecution( } object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging { - logError(s"### PopulateSchemaV3Rule, batchId = $currentBatchId") override val rule: PartialFunction[SparkPlan, SparkPlan] = { - case tws: TransformWithStateExec => + case tws: TransformWithStateExec if isFirstBatch && currentBatchId != 0 => val stateSchemaV3File = new StateSchemaV3File( hadoopConf, tws.stateSchemaFilePath().toString) - logError(s"### trying to get schema from file: ${tws.stateSchemaFilePath()}") stateSchemaV3File.getLatest() match { case Some((_, schemaJValue)) => - logError("### PASSING SCHEMA TO OPERATOR") - logError(s"### schemaJValue: $schemaJValue") tws.copy(columnFamilyJValue = Some(schemaJValue)) case None => tws } @@ -471,7 +467,6 @@ class IncrementalExecution( } override def apply(plan: SparkPlan): SparkPlan = { - logError(s"### applying rules to plan") val planWithStateOpId = plan transform composedRule val planWithSchema = planWithStateOpId transform PopulateSchemaV3Rule.rule // Need to check before write to metadata because we need to detect add operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70b5909ca5fe6..40809f1108865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -899,39 +899,38 @@ class MicroBatchExecution( */ protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = { watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan) + val shouldWriteMetadatas = execCtx.previousContext match { + case Some(prevCtx) + if prevCtx.executionPlan.runId == execCtx.executionPlan.runId => + false + case _ => true + } + + if (shouldWriteMetadatas) { + execCtx.executionPlan.executedPlan.collect { + case tws: TransformWithStateExec => + val schema = tws.getColumnFamilyJValue() + val metadata = tws.operatorStateMetadata() + val id = metadata.operatorInfo.operatorId + val schemaFile = stateSchemaLogs(id) + if (!schemaFile.add(execCtx.batchId, schema)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) + } + } + execCtx.executionPlan.executedPlan.collect { + case s: StateStoreWriter => + val metadata = s.operatorStateMetadata() + val id = metadata.operatorInfo.operatorId + val metadataFile = operatorStateMetadataLogs(id) + if (!metadataFile.add(execCtx.batchId, metadata)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) + } + } + } execCtx.reportTimeTaken("commitOffsets") { if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } - val shouldWriteMetadatas = 
execCtx.previousContext match { - case Some(prevCtx) - if prevCtx.executionPlan.runId == execCtx.executionPlan.runId => - false - case _ => true - } - - if (shouldWriteMetadatas) { - execCtx.executionPlan.executedPlan.collect { - case tws: TransformWithStateExec => - val schema = tws.getColumnFamilyJValue() - val metadata = tws.operatorStateMetadata() - val id = metadata.operatorInfo.operatorId - val schemaFile = stateSchemaLogs(id) - logError(s"Writing schema for operator $id at path ${schemaFile.metadataPath}") - if (!schemaFile.add(execCtx.batchId, schema)) { - throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) - } - } - execCtx.executionPlan.executedPlan.collect { - case s: StateStoreWriter => - val metadata = s.operatorStateMetadata() - val id = metadata.operatorInfo.operatorId - val metadataFile = operatorStateMetadataLogs(id) - if (!metadataFile.add(execCtx.batchId, metadata)) { - throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) - } - } - } } committedOffsets ++= execCtx.endOffsets } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 764b583705a2e..54469780854ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -115,12 +115,7 @@ case class TransformWithStateExec( } def columnFamilySchemas(): List[ColumnFamilySchema] = { - val columnFamilySchemas = ColumnFamilySchemaV1.fromJValue(columnFamilyJValue) - columnFamilySchemas.foreach { - case c1: ColumnFamilySchemaV1 => logError(s"### colFamilyName:" + - s"${c1.columnFamilyName}") - } - columnFamilySchemas + ColumnFamilySchemaV1.fromJValue(columnFamilyJValue) } override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 986006bc5973e..fd42008c941ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -919,14 +919,16 @@ class TransformWithStateSuite extends StateStoreMetricsTest OutputMode.Append()) // verify that query with ttl enabled after restart fails - testStream(result, OutputMode.Append())( + testStream(result2, OutputMode.Append())( StartStream( Trigger.ProcessingTime("1 second"), triggerClock = clock, checkpointLocation = chkptDir.getCanonicalPath ), AddData(inputData, "a"), - AdvanceManualClock(1 * 1000) + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + StopStream ) } } From f4002029c2df113bf5a5a66da0434a2f237b0e1f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 13 Jun 2024 11:16:41 -0700 Subject: [PATCH 3/9] adding purging --- .../spark/sql/execution/streaming/MicroBatchExecution.scala | 2 +- .../apache/spark/sql/execution/streaming/StreamExecution.scala | 2 ++ .../apache/spark/sql/streaming/TransformWithStateSuite.scala | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 40809f1108865..387d9d4a8c6e0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -909,10 +909,10 @@ class MicroBatchExecution( if (shouldWriteMetadatas) { execCtx.executionPlan.executedPlan.collect { case tws: TransformWithStateExec => - val schema = tws.getColumnFamilyJValue() val metadata = tws.operatorStateMetadata() val id = metadata.operatorInfo.operatorId val schemaFile = stateSchemaLogs(id) + val schema = tws.getColumnFamilyJValue() if (!schemaFile.add(execCtx.batchId, schema)) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c217e5b7f9978..63d009ed928c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -715,6 +715,8 @@ abstract class StreamExecution( protected def purgeOldest(): Unit = { operatorStateMetadataLogs.foreach( _._2.purgeOldest(minLogEntriesToMaintain)) + stateSchemaLogs.foreach( + _._2.purgeOldest(minLogEntriesToMaintain)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index fd42008c941ed..2247ad9c321d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -910,7 +910,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest StopStream ) - logError(s"### Restarting stream") val result2 = inputData.toDS() .groupByKey(x => x) .transformWithState( From 45e82c9a380471f27bb9edc7ae03fb7337810caa Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 13 Jun 2024 12:00:30 -0700 Subject: [PATCH 4/9] checking if file exists in add methods --- .../streaming/OperatorStateMetadataLog.scala | 28 +++++++++++++++++++ .../streaming/StateSchemaV3File.scala | 22 +++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala index f77875279384f..8bbd589a12885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -66,4 +66,32 @@ class OperatorStateMetadataLog( case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader) } } + + + /** + * Store the metadata for the specified batchId and return `true` if successful. If the batchId's + * metadata has already been stored, this method will return `false`. 
+ */ + override def add(batchId: Long, metadata: OperatorStateMetadata): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") + val batchMetadataFile = batchIdToPath(batchId) + if (fileManager.exists(batchMetadataFile)) { + fileManager.delete(batchMetadataFile) + } + val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) } + if (metadataCacheEnabled && res) batchCache.put(batchId, metadata) + res + } + + override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = { + val batchMetadataFile = batchIdToPath(batchId) + + if (metadataCacheEnabled && batchCache.containsKey(batchId)) { + false + } else { + write(batchMetadataFile, fn) + true + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala index 426b03ed3c424..82bab9a5301f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala @@ -75,4 +75,26 @@ class StateSchemaV3File( val json = buf.toString() JsonMethods.parse(json) } + + override def add(batchId: Long, metadata: JValue): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") + val batchMetadataFile = batchIdToPath(batchId) + if (fileManager.exists(batchMetadataFile)) { + fileManager.delete(batchMetadataFile) + } + val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) } + if (metadataCacheEnabled && res) batchCache.put(batchId, metadata) + res + } + + override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = { + val batchMetadataFile = batchIdToPath(batchId) + + if (metadataCacheEnabled && batchCache.containsKey(batchId)) { + false + } else { + write(batchMetadataFile, fn) + true + } + } } From 95d5b2a862e14da493e66589510de7201752e744 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 14 Jun 2024 07:04:10 -0700 Subject: [PATCH 5/9] tests pass --- .../sql/execution/streaming/ListStateImpl.scala | 10 ++++++++++ .../execution/streaming/ListStateImplWithTTL.scala | 10 ++++++++++ .../spark/sql/execution/streaming/MapStateImpl.scala | 10 ++++++++++ .../execution/streaming/MapStateImplWithTTL.scala | 9 +++++++++ .../streaming/StatefulProcessorHandleImpl.scala | 12 +++++++----- .../execution/streaming/ValueStateImplWithTTL.scala | 11 +++++++++++ .../streaming/state/OperatorStateMetadata.scala | 3 --- 7 files changed, 57 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index b74afd6f418db..fe31315eeb020 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -23,6 +23,16 @@ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSch import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState +object ListStateImpl { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + 
From 95d5b2a862e14da493e66589510de7201752e744 Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 14 Jun 2024 07:04:10 -0700
Subject: [PATCH 5/9] tests pass

---
 .../sql/execution/streaming/ListStateImpl.scala        | 10 ++++++++++
 .../execution/streaming/ListStateImplWithTTL.scala     | 10 ++++++++++
 .../spark/sql/execution/streaming/MapStateImpl.scala   | 10 ++++++++++
 .../execution/streaming/MapStateImplWithTTL.scala      |  9 +++++++++
 .../streaming/StatefulProcessorHandleImpl.scala        | 12 +++++++-----
 .../execution/streaming/ValueStateImplWithTTL.scala    | 11 +++++++++++
 .../streaming/state/OperatorStateMetadata.scala        |  3 ---
 7 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index b74afd6f418db..fe31315eeb020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSch
 import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.ListState
 
+object ListStateImpl {
+  def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
+    new ColumnFamilySchemaV1(
+      stateName,
+      KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+      true)
+  }
+}
 /**
  * Provides concrete implementation for list of values associated with a state variable
  * used in the streaming transformWithState operator.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
index b5b902ab98245..89e62c8a5a864 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoP
 import org.apache.spark.sql.streaming.{ListState, TTLConfig}
 import org.apache.spark.util.NextIterator
 
+object ListStateImplWithTTL {
+  def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
+    new ColumnFamilySchemaV1(
+      stateName,
+      KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA_WITH_TTL,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+      true)
+  }
+}
 /**
  * Class that provides a concrete implementation for a list state state associated with state
  * variables (with ttl expiration support) used in the streaming transformWithState operator.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
index b9558c1b6c310..9cfc22f6b7a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSch
 import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
 import org.apache.spark.sql.streaming.MapState
 
+object MapStateImpl {
+  def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
+    new ColumnFamilySchemaV1(
+      stateName,
+      COMPOSITE_KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
+  }
+}
+
 class MapStateImpl[K, V](
     store: StateStore,
     stateName: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
index ef54624b9ecb0..4a7c939cc78f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
@@ -24,6 +24,15 @@ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, Pre
 import org.apache.spark.sql.streaming.{MapState, TTLConfig}
 import org.apache.spark.util.NextIterator
 
+object MapStateImplWithTTL {
+  def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
+    new ColumnFamilySchemaV1(
+      stateName,
+      COMPOSITE_KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA_WITH_TTL,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
+  }
+}
 /**
  * Class that provides a concrete implementation for map state associated with state
  * variables (with ttl expiration support) used in the streaming transformWithState operator.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 1ba3cc5a981d3..27cba0d11eca4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -142,7 +142,7 @@ class StatefulProcessorHandleImpl(
         new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
-        val colFamilySchema = resultState.columnFamilySchema
+        val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -163,7 +163,7 @@ class StatefulProcessorHandleImpl(
         valueStateWithTTL
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
-        val colFamilySchema = resultState.columnFamilySchema
+        val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -268,6 +268,8 @@ class StatefulProcessorHandleImpl(
         new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, false))
+        val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)
+        columnFamilySchemas.add(colFamilySchema)
         null
     }
   }
@@ -303,7 +305,7 @@ class StatefulProcessorHandleImpl(
         listStateWithTTL
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, true))
-        val colFamilySchema = resultState.columnFamilySchema
+        val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -320,7 +322,7 @@ class StatefulProcessorHandleImpl(
         new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
-        val colFamilySchema = resultState.columnFamilySchema
+        val colFamilySchema = MapStateImpl.columnFamilySchema(stateName)
        columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -342,7 +344,7 @@ class StatefulProcessorHandleImpl(
         mapStateWithTTL
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, MapState, true))
-        val colFamilySchema = resultState.columnFamilySchema
+        val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
index cf98697878574..3318dd29be4a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -22,6 +22,17 @@ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSch
 import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore}
 import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
 
+object ValueStateImplWithTTL {
+  def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
+    new ColumnFamilySchemaV1(
+      stateName,
+      KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA_WITH_TTL,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+      false)
+  }
+}
+
 /**
  * Class that provides a concrete implementation for a single value state associated with state
  * variables (with ttl expiration support) used in the streaming transformWithState operator.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index 36bfb34edc412..3bc1442371322 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -25,13 +25,10 @@ import scala.reflect.ClassTag
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FSDataOutputStream, Path}
 import org.json4s.{Formats, NoTypeHints}
-import org.json4s.JsonAST.JValue
 import org.json4s.jackson.Serialization
 
-import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
-import org.apache.spark.util.AccumulatorV2
 
 /**
  * Metadata for a state store instance.
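The net effect of this commit is that each state variable type now derives its column family schema from a single companion-object factory, so the driver (which has no StateStore during planning) and the executor (which actually creates the column family) agree on one definition. A rough standalone sketch of the pattern, using simplified stand-in types rather than the real ColumnFamilySchemaV1 and row schemas:

final case class ColFamilySchema(name: String, keySchema: String, valueSchema: String, ttl: Boolean)

object ValueStateSchema {
  def columnFamilySchema(stateName: String): ColFamilySchema =
    ColFamilySchema(stateName, "keyRow", "valueRow", ttl = false)
}

object ListStateWithTTLSchema {
  def columnFamilySchema(stateName: String): ColFamilySchema =
    ColFamilySchema(stateName, "keyRow", "valueRowWithTTL", ttl = true)
}

object SchemaFactoryDemo extends App {
  // Driver side: collect the schemas for planning without touching a store.
  val planned = Seq(
    ValueStateSchema.columnFamilySchema("countState"),
    ListStateWithTTLSchema.columnFamilySchema("eventsState"))

  // Executor side: the same factories drive column family creation.
  planned.foreach(s => println(s"createColFamilyIfAbsent(${s.name}, ttl=${s.ttl})"))
}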
From c383ba7291ccf31c04c4158339a251b87a528b82 Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 14 Jun 2024 07:21:31 -0700
Subject: [PATCH 6/9] adding block to purge any metadata files for batch ids
 greater than or equal to this batchId

---
 .../sql/execution/streaming/MicroBatchExecution.scala | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 387d9d4a8c6e0..99d5f24201386 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -907,6 +907,15 @@ class MicroBatchExecution(
     }
 
     if (shouldWriteMetadatas) {
+      // clean up any batchIds that are greater than or equal to
+      // the current batchId
+      execCtx.executionPlan.executedPlan.collect {
+        case tws: TransformWithStateExec =>
+          val metadata = tws.operatorStateMetadata()
+          val id = metadata.operatorInfo.operatorId
+          val metadataFile = operatorStateMetadataLogs(id)
+          metadataFile.purgeAfter(execCtx.batchId - 1)
+      }
       execCtx.executionPlan.executedPlan.collect {
         case tws: TransformWithStateExec =>
           val metadata = tws.operatorStateMetadata()
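The purge added here guards restarts: if an earlier run already wrote operator metadata for this batch id or a later one, those entries are dropped before the current run writes its own. A small self-contained model of that behaviour, with a sorted map standing in for the per-operator metadata log (names are illustrative, not from the patch):

import scala.collection.mutable

class MetadataLogModel {
  private val entries = mutable.SortedMap.empty[Long, String]

  def add(batchId: Long, metadata: String): Unit = entries.put(batchId, metadata)

  // Drop every entry with a batch id strictly greater than the threshold.
  def purgeAfter(thresholdBatchId: Long): Unit =
    entries.keys.filter(_ > thresholdBatchId).toList.foreach(entries.remove)

  def batchIds: Seq[Long] = entries.keys.toSeq
}

object RestartCleanupDemo extends App {
  val log = new MetadataLogModel
  (0L to 5L).foreach(id => log.add(id, s"metadata-$id"))  // a previous run reached batch 5
  val restartBatchId = 3L                                 // the query restarts at batch 3
  log.purgeAfter(restartBatchId - 1)                      // removes entries for batches 3, 4, 5
  log.add(restartBatchId, "metadata-3-rewritten")
  println(log.batchIds)                                   // batches 0, 1, 2, 3 remain
}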
From 68064f1d24c5061f00b6d35ceea9d8ed29e4d4ea Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 14 Jun 2024 07:35:03 -0700
Subject: [PATCH 7/9] removing duplicate code

---
 .../spark/sql/execution/streaming/ListStateImpl.scala       | 4 +---
 .../sql/execution/streaming/ListStateImplWithTTL.scala      | 5 +----
 .../apache/spark/sql/execution/streaming/MapStateImpl.scala | 6 +-----
 .../spark/sql/execution/streaming/MapStateImplWithTTL.scala | 5 +----
 .../spark/sql/execution/streaming/ValueStateImpl.scala      | 4 +---
 .../sql/execution/streaming/ValueStateImplWithTTL.scala     | 5 +----
 6 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index fe31315eeb020..429464a5467b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -54,9 +54,7 @@ class ListStateImpl[S](
 
   private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
-  store.createColFamilyIfAbsent(columnFamilySchema)
+  store.createColFamilyIfAbsent(ListStateImpl.columnFamilySchema(stateName))
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
index 89e62c8a5a864..969ad8a889fc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
@@ -62,13 +62,10 @@ class ListStateImplWithTTL[S](
   private lazy val ttlExpirationMs =
     StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
-    NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true)
   initialize()
 
   private def initialize(): Unit = {
-    store.createColFamilyIfAbsent(columnFamilySchema)
+    store.createColFamilyIfAbsent(ListStateImplWithTTL.columnFamilySchema(stateName))
   }
 
   /** Whether state exists or not. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
index 9cfc22f6b7a1a..0d3a0be5cf5e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
@@ -44,11 +44,7 @@ class MapStateImpl[K, V](
   private val stateTypesEncoder = new CompositeKeyStateEncoder(
     keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
-    PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
-
-  store.createColFamilyIfAbsent(columnFamilySchema)
+  store.createColFamilyIfAbsent(MapStateImpl.columnFamilySchema(stateName))
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
index 4a7c939cc78f3..cb99ccf248f9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
@@ -64,13 +64,10 @@ class MapStateImplWithTTL[K, V](
   private val ttlExpirationMs =
     StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
-    PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
   initialize()
 
   private def initialize(): Unit = {
-    store.createColFamilyIfAbsent(columnFamilySchema)
+    store.createColFamilyIfAbsent(MapStateImplWithTTL.columnFamilySchema(stateName))
   }
 
   /** Whether state exists or not. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index 50816822e72e8..ea32ccf29bfab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -53,12 +53,10 @@ class ValueStateImpl[S](
   private val keySerializer = keyExprEnc.createSerializer()
   private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
   initialize()
 
   private def initialize(): Unit = {
-    store.createColFamilyIfAbsent(columnFamilySchema)
+    store.createColFamilyIfAbsent(ValueStateImpl.columnFamilySchema(stateName))
   }
 
   /** Function to check if state exists. Returns true if present and false otherwise */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
index 3318dd29be4a1..428bfa1d75776 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -60,13 +60,10 @@ class ValueStateImplWithTTL[S](
   private val ttlExpirationMs =
     StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
 
-  val columnFamilySchema = new ColumnFamilySchemaV1(
-    stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
-    NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
   initialize()
 
   private def initialize(): Unit = {
-    store.createColFamilyIfAbsent(columnFamilySchema)
+    store.createColFamilyIfAbsent(ValueStateImplWithTTL.columnFamilySchema(stateName))
   }
 
  /** Function to check if state exists. Returns true if present and false otherwise */

From e935292aae4bf0320c9971f10246907d857161bb Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 14 Jun 2024 08:43:25 -0700
Subject: [PATCH 8/9] validating state variable creation

---
 .../streaming/StatefulProcessorHandleImpl.scala    | 30 ++++++++++++++++---
 .../streaming/TransformWithStateExec.scala         |  7 ++++-
 .../streaming/TransformWithStateSuite.scala        |  9 ++++--
 3 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 27cba0d11eca4..1d0034be00735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -85,7 +85,8 @@ class StatefulProcessorHandleImpl(
     timeMode: TimeMode,
     isStreaming: Boolean = true,
     batchTimestampMs: Option[Long] = None,
-    metrics: Map[String, SQLMetric] = Map.empty)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    existingColumnFamilies: Map[String, ColumnFamilySchema] = Map.empty)
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._
 
@@ -132,6 +133,19 @@ class StatefulProcessorHandleImpl(
 
   def getHandleState: StatefulProcessorHandleState = currState
 
+  private def verifyStateVariableCreation(columnFamilySchema: ColumnFamilySchema): Unit = {
+    columnFamilySchema match {
+      case c1: ColumnFamilySchemaV1 if existingColumnFamilies.contains(c1.columnFamilyName) =>
+        val existingColumnFamily = existingColumnFamilies(c1.columnFamilyName)
+        if (existingColumnFamily.json != columnFamilySchema.json) {
+          throw new RuntimeException(
+            s"State variable with name ${c1.columnFamilyName} already exists " +
+              s"with different schema.")
+        }
+      case _ =>
+    }
+  }
+
   override def getValueState[T](
       stateName: String,
       valEncoder: Encoder[T]): ValueState[T] = {
@@ -143,6 +157,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
         val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -164,6 +179,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
         val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -254,9 +270,11 @@ class StatefulProcessorHandleImpl(
    * @param stateName - name of the state variable
    */
   override def deleteIfExists(stateName: String): Unit = {
-    verifyStateVarOperations("delete_if_exists")
-    if (store.get.removeColFamilyIfExists(stateName)) {
-      incrementMetric("numDeletedStateVars")
+    if (store.isDefined) {
+      verifyStateVarOperations("delete_if_exists")
+      if (store.get.removeColFamilyIfExists(stateName)) {
+        incrementMetric("numDeletedStateVars")
+      }
     }
   }
 
@@ -268,6 +287,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, false))
         val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -303,6 +325,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, true))
         val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -323,6 +343,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
         val colFamilySchema = MapStateImpl.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -345,6 +366,7 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, MapState, true))
         val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName)
+        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 54469780854ea..036fbb64e3877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -384,9 +384,14 @@ case class TransformWithStateExec(
 
     validateTimeMode()
 
+    val existingColumnFamilies = columnFamilySchemas().map {
+      case c1: ColumnFamilySchemaV1 =>
+        c1.columnFamilyName -> c1
+    }.toMap
+
     val driverProcessorHandle = new StatefulProcessorHandleImpl(
       None, getStateInfo.queryRunId, keyEncoder, timeMode,
-      isStreaming, batchTimestampMs, metrics)
+      isStreaming, batchTimestampMs, metrics, existingColumnFamilies)
 
     driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
     statefulProcessor.setHandle(driverProcessorHandle)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 2247ad9c321d2..513bcd7540460 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -926,8 +926,13 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         ),
         AddData(inputData, "a"),
         AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(("a", "1")),
-        StopStream
+        Execute { q =>
+          val e = intercept[Exception] {
+            q.processAllAvailable()
+          }
+          assert(e.getMessage.contains("State variable with name" +
+            " countState already exists with different schema"))
+        }
       )
     }
   }
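The failure exercised by the updated test comes down to comparing the column family schema a restarted query declares against the one previously recorded under the same state variable name, and rejecting the restart when they differ. A minimal standalone sketch of that check, with JSON strings standing in for the serialized ColumnFamilySchemaV1 entries (a simplification of verifyStateVariableCreation above, not the real implementation):

final case class RecordedSchema(stateName: String, json: String)

object SchemaCompatibility {
  // Fails when a state variable that already exists is redeclared with a different schema.
  def verify(existing: Seq[RecordedSchema], declared: Seq[RecordedSchema]): Unit = {
    val declaredByName = declared.map(s => s.stateName -> s).toMap
    existing.foreach { old =>
      declaredByName.get(old.stateName).foreach { fresh =>
        if (fresh.json != old.json) {
          throw new RuntimeException(
            s"State variable with name ${old.stateName} already exists with different schema.")
        }
      }
    }
  }
}

object SchemaCompatibilityDemo extends App {
  val existing = Seq(RecordedSchema("countState", """{"value":"long"}"""))
  val declared = Seq(RecordedSchema("countState", """{"value":"string"}"""))
  SchemaCompatibility.verify(existing, declared)  // throws: countState redeclared differently
}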
From 52a579e0842fc9cb5e33e9511eb106a8a1fa8058 Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 14 Jun 2024 13:26:07 -0700
Subject: [PATCH 9/9] writing schema and metadata in planning rule

---
 .../streaming/IncrementalExecution.scala      | 44 +++++++++++--
 .../streaming/MicroBatchExecution.scala       | 37 -----------
 .../StatefulProcessorHandleImpl.scala         | 19 ------
 .../streaming/TransformWithStateExec.scala    | 65 ++++++++++++-------
 4 files changed, 80 insertions(+), 85 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index f7bc2fb5f27ef..1b6b17cf11d9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -21,6 +21,7 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.hadoop.fs.Path
+import org.json4s.JsonAST.JValue
 
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR}
@@ -37,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -187,15 +188,48 @@ class IncrementalExecution(
     }
   }
 
+  def writeSchemaAndMetadataFiles(
+      stateSchemaV3File: StateSchemaV3File,
+      operatorStateMetadataLog: OperatorStateMetadataLog,
+      stateSchema: JValue,
+      operatorStateMetadata: OperatorStateMetadata): Unit = {
+    operatorStateMetadataLog.purgeAfter(currentBatchId - 1)
+    if (!stateSchemaV3File.add(currentBatchId, stateSchema)) {
+      throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
+    }
+    if (!operatorStateMetadataLog.add(currentBatchId, operatorStateMetadata)) {
+      throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
+    }
+  }
+
   object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging {
     override val rule: PartialFunction[SparkPlan, SparkPlan] = {
-      case tws: TransformWithStateExec if isFirstBatch && currentBatchId != 0 =>
+      case tws: TransformWithStateExec if isFirstBatch =>
         val stateSchemaV3File = new StateSchemaV3File(
           hadoopConf, tws.stateSchemaFilePath().toString)
+        val operatorStateMetadataLog = new OperatorStateMetadataLog(
+          hadoopConf,
+          tws.metadataFilePath().toString
+        )
         stateSchemaV3File.getLatest() match {
-          case Some((_, schemaJValue)) =>
-            tws.copy(columnFamilyJValue = Some(schemaJValue))
-          case None => tws
+          case Some((_, oldSchema)) =>
+            val newSchema = tws.getSchema()
+            tws.compareSchemas(oldSchema, newSchema)
+            writeSchemaAndMetadataFiles(
+              stateSchemaV3File = stateSchemaV3File,
+              operatorStateMetadataLog = operatorStateMetadataLog,
+              stateSchema = newSchema,
+              operatorStateMetadata = tws.operatorStateMetadata()
+            )
+            tws.copy(columnFamilyJValue = Some(oldSchema))
+          case None =>
+            writeSchemaAndMetadataFiles(
+              stateSchemaV3File = stateSchemaV3File,
+              operatorStateMetadataLog = operatorStateMetadataLog,
+              stateSchema = tws.getSchema(),
+              operatorStateMetadata = tws.operatorStateMetadata()
+            )
+            tws
         }
     }
   }
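Taken together, the rule above now runs the whole schema lifecycle at planning time on the first batch: read the latest recorded schema if one exists, compare it with the operator's current schema, then purge stale metadata entries and write the new schema and operator metadata for the current batch. A compact standalone sketch of that control flow, with sorted maps standing in for StateSchemaV3File and OperatorStateMetadataLog and whole-string equality standing in for compareSchemas (all simplifications, not the real APIs):

import scala.collection.mutable

object PlanningTimeSchemaFlow extends App {
  val schemaFile = mutable.SortedMap.empty[Long, String]    // batchId -> schema JSON
  val metadataLog = mutable.SortedMap.empty[Long, String]   // batchId -> operator metadata

  def writeSchemaAndMetadata(batchId: Long, schema: String, metadata: String): Unit = {
    metadataLog.keys.filter(_ >= batchId).toList.foreach(metadataLog.remove)  // purgeAfter(batchId - 1)
    schemaFile.put(batchId, schema)
    metadataLog.put(batchId, metadata)
  }

  def onFirstBatch(currentBatchId: Long, newSchema: String, metadata: String): Unit =
    schemaFile.lastOption match {
      case Some((_, oldSchema)) =>
        require(oldSchema == newSchema, "state schema changed across restart")
        writeSchemaAndMetadata(currentBatchId, newSchema, metadata)
      case None =>
        writeSchemaAndMetadata(currentBatchId, newSchema, metadata)
    }

  onFirstBatch(0L, """{"countState":"long"}""", "operator-metadata-v2")
  onFirstBatch(1L, """{"countState":"long"}""", "operator-metadata-v2")  // restart, same schema
  println(schemaFile.keys.toSeq)  // both batch 0 and batch 1 have a schema entry
}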
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 99d5f24201386..20dfcd7c7fd8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -899,43 +899,6 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = {
     watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
-    val shouldWriteMetadatas = execCtx.previousContext match {
-      case Some(prevCtx)
-        if prevCtx.executionPlan.runId == execCtx.executionPlan.runId =>
-        false
-      case _ => true
-    }
-
-    if (shouldWriteMetadatas) {
-      // clean up any batchIds that are greater than or equal to
-      // the current batchId
-      execCtx.executionPlan.executedPlan.collect {
-        case tws: TransformWithStateExec =>
-          val metadata = tws.operatorStateMetadata()
-          val id = metadata.operatorInfo.operatorId
-          val metadataFile = operatorStateMetadataLogs(id)
-          metadataFile.purgeAfter(execCtx.batchId - 1)
-      }
-      execCtx.executionPlan.executedPlan.collect {
-        case tws: TransformWithStateExec =>
-          val metadata = tws.operatorStateMetadata()
-          val id = metadata.operatorInfo.operatorId
-          val schemaFile = stateSchemaLogs(id)
-          val schema = tws.getColumnFamilyJValue()
-          if (!schemaFile.add(execCtx.batchId, schema)) {
-            throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
-          }
-      }
-      execCtx.executionPlan.executedPlan.collect {
-        case s: StateStoreWriter =>
-          val metadata = s.operatorStateMetadata()
-          val id = metadata.operatorInfo.operatorId
-          val metadataFile = operatorStateMetadataLogs(id)
-          if (!metadataFile.add(execCtx.batchId, metadata)) {
-            throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
-          }
-      }
-    }
     execCtx.reportTimeTaken("commitOffsets") {
       if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
         throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 1d0034be00735..b14eea3e5feb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -133,19 +133,6 @@ class StatefulProcessorHandleImpl(
 
   def getHandleState: StatefulProcessorHandleState = currState
 
-  private def verifyStateVariableCreation(columnFamilySchema: ColumnFamilySchema): Unit = {
-    columnFamilySchema match {
-      case c1: ColumnFamilySchemaV1 if existingColumnFamilies.contains(c1.columnFamilyName) =>
-        val existingColumnFamily = existingColumnFamilies(c1.columnFamilyName)
-        if (existingColumnFamily.json != columnFamilySchema.json) {
-          throw new RuntimeException(
-            s"State variable with name ${c1.columnFamilyName} already exists " +
-              s"with different schema.")
-        }
-      case _ =>
-    }
-  }
-
   override def getValueState[T](
       stateName: String,
       valEncoder: Encoder[T]): ValueState[T] = {
@@ -157,7 +144,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
         val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -179,7 +165,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
         val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -287,7 +272,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, false))
         val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -325,7 +309,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ListState, true))
         val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -343,7 +326,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
         val colFamilySchema = MapStateImpl.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
@@ -366,7 +348,6 @@ class StatefulProcessorHandleImpl(
       case None =>
         stateVariables.add(new StateVariableInfo(stateName, MapState, true))
         val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName)
-        verifyStateVariableCreation(colFamilySchema)
         columnFamilySchemas.add(colFamilySchema)
         null
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 036fbb64e3877..22528d6f7068f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -92,8 +92,6 @@ case class TransformWithStateExec(
 
   override def shortName: String = "transformWithStateExec"
 
-  columnFamilySchemas()
-
   /** Metadata of this stateful operator and its states stores. */
   override def operatorStateMetadata(): OperatorStateMetadata = {
     val info = getStateInfo
@@ -101,21 +99,44 @@ case class TransformWithStateExec(
     val stateStoreInfo =
       Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions))
 
+    val driverProcessorHandle = getDriverProcessorHandle
+    val stateVariables = JArray(driverProcessorHandle.stateVariables.
+      asScala.map(_.jsonValue).toList)
+
+    closeProcessorHandle(driverProcessorHandle)
     val operatorPropertiesJson: JValue =
       ("timeMode" -> JString(timeMode.toString)) ~
         ("outputMode" -> JString(outputMode.toString)) ~
-        ("stateVariables" -> operatorProperties.get("stateVariables"))
+        ("stateVariables" -> stateVariables)
 
     val json = compact(render(operatorPropertiesJson))
     OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
   }
 
-  def getColumnFamilyJValue(): JValue = {
-    val columnFamilySchemas = operatorProperties.get("columnFamilySchemas")
+  def getSchema(): JValue = {
+    val driverProcessorHandle = getDriverProcessorHandle
+    val columnFamilySchemas = JArray(driverProcessorHandle.
+      columnFamilySchemas.asScala.map(_.jsonValue).toList)
+    closeProcessorHandle(driverProcessorHandle)
     columnFamilySchemas
   }
 
-  def columnFamilySchemas(): List[ColumnFamilySchema] = {
-    ColumnFamilySchemaV1.fromJValue(columnFamilyJValue)
+  def compareSchemas(oldSchema: JValue, newSchema: JValue): Unit = {
+    val oldColumnFamilies = ColumnFamilySchemaV1.fromJValue(oldSchema)
+    val newColumnFamilies = ColumnFamilySchemaV1.fromJValue(newSchema).map {
+      case c1: ColumnFamilySchemaV1 =>
+        c1.columnFamilyName -> c1
+    }.toMap
+
+    oldColumnFamilies.foreach {
+      case oldColumnFamily: ColumnFamilySchemaV1 =>
+        newColumnFamilies.get(oldColumnFamily.columnFamilyName) match {
+          case Some(newColumnFamily) if oldColumnFamily.json != newColumnFamily.json =>
+            throw new RuntimeException(
+              s"State variable with name ${newColumnFamily.columnFamilyName}" +
+                s" already exists with different schema.")
+          case _ => // do nothing
+        }
+    }
   }
 
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
@@ -379,30 +400,26 @@ case class TransformWithStateExec(
     )
   }
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    metrics // force lazy init at driver
-
-    validateTimeMode()
-
-    val existingColumnFamilies = columnFamilySchemas().map {
-      case c1: ColumnFamilySchemaV1 =>
-        c1.columnFamilyName -> c1
-    }.toMap
-
+  protected def getDriverProcessorHandle: StatefulProcessorHandleImpl = {
     val driverProcessorHandle = new StatefulProcessorHandleImpl(
       None, getStateInfo.queryRunId, keyEncoder, timeMode,
-      isStreaming, batchTimestampMs, metrics, existingColumnFamilies)
-
+      isStreaming, batchTimestampMs, metrics)
     driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
     statefulProcessor.setHandle(driverProcessorHandle)
     statefulProcessor.init(outputMode, timeMode)
-    operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables.
-      asScala.map(_.jsonValue).toList))
-    operatorProperties.put("columnFamilySchemas", JArray(driverProcessorHandle.
-      columnFamilySchemas.asScala.map(_.jsonValue).toList))
+    driverProcessorHandle
+  }
 
+  protected def closeProcessorHandle(processorHandle: StatefulProcessorHandleImpl): Unit = {
+    statefulProcessor.close()
     statefulProcessor.setHandle(null)
-    driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+    processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    metrics // force lazy init at driver
+
+    validateTimeMode()
 
     if (hasInitialState) {
       val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)