From 4849f20db379e79b6126e23d6b9c7024730d25e0 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 8 Jul 2024 10:23:50 -0700 Subject: [PATCH 01/22] a base change a draft suite --- .../StatefulProcessorHandleImpl.scala | 6 +++--- .../StatefulProcessorHandleImplBase.scala | 4 +++- .../streaming/TransformWithStateExec.scala | 2 +- .../streaming/TransformWithStateSuite.scala | 20 +++++++++++++++++++ 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 893163a58a1b9..734c1289af207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -87,7 +87,7 @@ class StatefulProcessorHandleImpl( isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) - extends StatefulProcessorHandleImplBase(timeMode) with Logging { + extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging { import StatefulProcessorHandleState._ /** @@ -297,8 +297,8 @@ class StatefulProcessorHandleImpl( * actually done. We need this class because we can only collect the schemas after * the StatefulProcessor is initialized. */ -class DriverStatefulProcessorHandleImpl(timeMode: TimeMode) - extends StatefulProcessorHandleImplBase(timeMode) { +class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) + extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) { private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala index 3b952967e35d9..12f9ce768e42b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.NoTime import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.{INITIALIZED, PRE_INIT, StatefulProcessorHandleState, TIMER_PROCESSED} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimeMode} -abstract class StatefulProcessorHandleImplBase(timeMode: TimeMode) +abstract class StatefulProcessorHandleImplBase( + timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) extends StatefulProcessorHandle { protected var currState: StatefulProcessorHandleState = PRE_INIT diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index acf46df9cc1fa..7ae8b5d1eb63e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -99,7 +99,7 @@ case class TransformWithStateExec( * @return a new 
instance of the driver processor handle */ private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = { - val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode) + val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder) driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) statefulProcessor.setHandle(driverProcessorHandle) statefulProcessor.init(outputMode, timeMode) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 5e408dc999f82..6c8e09bf7ba36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql.streaming import java.io.File import java.util.UUID +import org.json4s.JsonAST.JString + import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} @@ -899,3 +902,20 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { ) } } + +class TransformWithStateSchemaSuite extends StateStoreMetricsTest { + + test("schema") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + StateTypesEncoder(keySerializer = encoderFor(Encoders.scalaInt).createSerializer(), + valEncoder = Encoders.STRING, stateName = "someState", hasTtl = false) + + val keyExprEncoderSer = encoderFor(Encoders.scalaInt).schema + println("keyExprEncoder here: " + JString(keyExprEncoderSer.json)) + println("valueEncoder here: " + JString(Encoders.STRING.schema.json)) + } + } +} From 2bbd2cece16bf78282d21ded2e294e3b012af910 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 8 Jul 2024 11:08:52 -0700 Subject: [PATCH 02/22] working version, will write test suites and test for composite types --- .../streaming/ColumnFamilySchemaUtils.scala | 40 ++++++++++--------- .../streaming/StateTypesEncoderUtils.scala | 28 +++++++++++++ .../StatefulProcessorHandleImpl.scala | 12 +++--- .../StatefulProcessorHandleImplBase.scala | 3 +- .../streaming/TransformWithStateSuite.scala | 8 ++++ 5 files changed, 64 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala index feced6810d3a3..8cafbba3491f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala @@ -17,23 +17,30 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder -import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec} trait ColumnFamilySchemaUtils { def getValueStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchema def getListStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchema def getMapStateSchema[K, V]( stateName: String, + keyEncoder: ExpressionEncoder[Any], userKeyEnc: Encoder[K], + valEncoder: Encoder[V], hasTtl: Boolean): ColumnFamilySchema } @@ -41,44 +48,39 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { def getValueStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { new ColumnFamilySchemaV1( stateName, - KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + keyEncoder.schema, + getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } def getListStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { new ColumnFamilySchemaV1( stateName, - KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + keyEncoder.schema, + getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } def getMapStateSchema[K, V]( stateName: String, + keyEncoder: ExpressionEncoder[Any], userKeyEnc: Encoder[K], + valEncoder: Encoder[V], hasTtl: Boolean): ColumnFamilySchemaV1 = { + val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) new ColumnFamilySchemaV1( stateName, - COMPOSITE_KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + compositeKeySchema, + getValueSchemaWithTTL(valEncoder.schema, hasTtl), PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), Some(userKeyEnc.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index ed881b49ec1e9..f308208950dcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -26,6 +26,11 @@ import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.types.{BinaryType, LongType, StructType} object TransformWithStateKeyValueRowSchema { + /** + * The following are the key/value row schema used in StateStore layer. + * Key/value rows will be serialized into Binary format in `StateTypesEncoder`. + * The "real" key/value row schema will be written into state schema metadata. + */ val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) val COMPOSITE_KEY_ROW_SCHEMA: StructType = new StructType() .add("key", BinaryType) @@ -35,6 +40,29 @@ object TransformWithStateKeyValueRowSchema { val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType() .add("value", BinaryType) .add("ttlExpirationMs", LongType) + + /** + * Helper function for passing the key/value schema to write to state schema metadata. 
+ * Return value schema with additional TTL column if TTL is enabled. + * + * @param schema Value Schema returned by value encoder that user passed in + * @param hasTTL TTL enabled or not + * @return a schema with additional TTL column if TTL is enabled. + */ + def getValueSchemaWithTTL(schema: StructType, hasTTL: Boolean): StructType = { + if (hasTTL) { + new StructType(schema.fields).add("ttlExpirationMs", LongType) + } else schema + } + + /** + * Given grouping key and user key schema, return the schema of the composite key. + */ + def getCompositeKeySchema( + groupingKeySchema: StructType, + userKeySchema: StructType): StructType = { + new StructType(groupingKeySchema.fields ++ userKeySchema.fields) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 734c1289af207..aa16886624db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -322,7 +322,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getValueStateSchema(stateName, false) + getValueStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] } @@ -344,7 +344,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi ttlConfig: TTLConfig): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getValueStateSchema(stateName, true) + getValueStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] } @@ -362,7 +362,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getListStateSchema(stateName, false) + getListStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] } @@ -384,7 +384,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi ttlConfig: TTLConfig): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getListStateSchema(stateName, true) + getListStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] } @@ -406,7 +406,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. 
- getMapStateSchema(stateName, userKeyEnc, false) + getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] } @@ -430,7 +430,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getMapStateSchema(stateName, userKeyEnc, true) + getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala index 12f9ce768e42b..64d87073ccf9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimeMode} abstract class StatefulProcessorHandleImplBase( - timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) - extends StatefulProcessorHandle { + timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) extends StatefulProcessorHandle { protected var currState: StatefulProcessorHandleState = PRE_INIT diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 6c8e09bf7ba36..abb954e3a6145 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogChec import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 @@ -916,6 +917,13 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest { val keyExprEncoderSer = encoderFor(Encoders.scalaInt).schema println("keyExprEncoder here: " + JString(keyExprEncoderSer.json)) println("valueEncoder here: " + JString(Encoders.STRING.schema.json)) + println("composite schema: " + + new StructType().add("key", BinaryType) + .add("userKey", BinaryType)) + val keySchema = new StructType().add("key", BinaryType) + val userkeySchema = new StructType().add("userkeySchema", BinaryType) + println("composite schema copy: " + + StructType(keySchema.fields ++ userkeySchema.fields)) } } } From 4f5185a1c1752b3a2bedc828419b8f432912f47a Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 8 Jul 2024 14:31:03 -0700 Subject: [PATCH 03/22] a suite with composite type, why key encoder spec overwritten --- .../streaming/ColumnFamilySchemaUtils.scala | 4 +- .../streaming/StateTypesEncoderUtils.scala | 22 ++-- .../StatefulProcessorHandleImpl.scala | 2 +- .../streaming/state/StateStore.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 112 +++++++++++++++--- 5 files changed, 114 insertions(+), 30 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala index 8cafbba3491f6..cd56986f23b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala @@ -53,7 +53,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { hasTtl: Boolean): ColumnFamilySchemaV1 = { new ColumnFamilySchemaV1( stateName, - keyEncoder.schema, + getKeySchema(keyEncoder.schema), getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } @@ -65,7 +65,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { hasTtl: Boolean): ColumnFamilySchemaV1 = { new ColumnFamilySchemaV1( stateName, - keyEncoder.schema, + getKeySchema(keyEncoder.schema), getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index f308208950dcd..a802ef63576f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -41,18 +41,24 @@ object TransformWithStateKeyValueRowSchema { .add("value", BinaryType) .add("ttlExpirationMs", LongType) + /** Helper functions for passing the key/value schema to write to state schema metadata. */ + + /** + * Return key schema with key column name. + */ + def getKeySchema(schema: StructType): StructType = { + new StructType().add("key", schema) + } + /** - * Helper function for passing the key/value schema to write to state schema metadata. * Return value schema with additional TTL column if TTL is enabled. - * - * @param schema Value Schema returned by value encoder that user passed in - * @param hasTTL TTL enabled or not - * @return a schema with additional TTL column if TTL is enabled. 
*/ def getValueSchemaWithTTL(schema: StructType, hasTTL: Boolean): StructType = { - if (hasTTL) { + val valSchema = if (hasTTL) { new StructType(schema.fields).add("ttlExpirationMs", LongType) } else schema + new StructType() + .add("value", valSchema) } /** @@ -61,7 +67,9 @@ object TransformWithStateKeyValueRowSchema { def getCompositeKeySchema( groupingKeySchema: StructType, userKeySchema: StructType): StructType = { - new StructType(groupingKeySchema.fields ++ userKeySchema.fields) + new StructType() + .add("key", new StructType(groupingKeySchema.fields)) + .add("userKey", new StructType(userKeySchema.fields)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index aa16886624db2..5b6abdd5fd1e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -430,7 +430,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true) + getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 41167a6c917d7..5ed4c76139ec1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -298,8 +298,8 @@ object KeyStateEncoderSpec { asInstanceOf[List[_]].map(_.asInstanceOf[Int]) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => - val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int] - PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt] + PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey.toInt) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index abb954e3a6145..a4594b5abad4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.sql.streaming import java.io.File import java.util.UUID -import org.json4s.JsonAST.JString +import org.apache.hadoop.fs.Path import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders} -import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, 
ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -310,6 +309,21 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess } } +class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor { + @transient private var _listState: ListState[TestClass] = _ + @transient private var _mapState: MapState[String, POJOTestClass] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) + _listState = getHandle.getListState[TestClass]( + "listState", Encoders.product[TestClass]) + _mapState = getHandle.getMapState[String, POJOTestClass]( + "mapState", Encoders.STRING, Encoders.bean(classOf[POJOTestClass])) + } +} + /** * Class that adds tests for transformWithState stateful streaming operator */ @@ -906,24 +920,86 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { class TransformWithStateSchemaSuite extends StateStoreMetricsTest { - test("schema") { + import testImplicits._ + + test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - StateTypesEncoder(keySerializer = encoderFor(Encoders.scalaInt).createSerializer(), - valEncoder = Encoders.STRING, stateName = "someState", hasTtl = false) - - val keyExprEncoderSer = encoderFor(Encoders.scalaInt).schema - println("keyExprEncoder here: " + JString(keyExprEncoderSer.json)) - println("valueEncoder here: " + JString(Encoders.STRING.schema.json)) - println("composite schema: " + - new StructType().add("key", BinaryType) - .add("userKey", BinaryType)) - val keySchema = new StructType().add("key", BinaryType) - val userkeySchema = new StructType().add("userkeySchema", BinaryType) - println("composite schema copy: " + - StructType(keySchema.fields ++ userkeySchema.fields)) + withTempDir { checkpointDir => + val metadataPathPostfix = "state/0/default/_metadata" + val stateSchemaPath = new Path(checkpointDir.toString, + s"$metadataPathPostfix/schema/0") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) + + val schema0 = ColumnFamilySchemaV1( + "countState", + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType)), + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + None + ) + val schema1 = ColumnFamilySchemaV1( + "listState", + new StructType().add("key", + new 
StructType().add("value", StringType)), + new StructType().add("value", + new StructType() + .add("id", LongType) + .add("name", StringType)), + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + None + ) + val schema2 = ColumnFamilySchemaV1( + "mapState", + new StructType() + .add("key", new StructType().add("value", StringType)) + .add("userKey", new StructType().add("value", StringType)), + new StructType().add("value", + new StructType() + .add("id", IntegerType) + .add("name", StringType)), + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), + Option(new StructType().add("value", StringType)) + ) + println("print out schema0: " + schema0) + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithCompositeTypes(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + Execute { q => + val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath + val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString, + metadataPathPostfix).toString) + val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath)) + + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) + + assert(colFamilySeq.length == 3) + assert(colFamilySeq.toSet == Set( + schema0, schema1, schema2 + )) + }, + StopStream + ) + } } } } From 00741ff6a088198790d624d9a17b7f9c1385c79e Mon Sep 17 00:00:00 2001 From: jingz-db Date: Mon, 8 Jul 2024 18:00:54 -0700 Subject: [PATCH 04/22] fix suites & add TTL suites --- .../streaming/state/StateStore.scala | 5 +- .../streaming/TransformWithStateSuite.scala | 11 +- .../TransformWithValueStateTTLSuite.scala | 142 +++++++++++++++++- 3 files changed, 144 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 5ed4c76139ec1..a244e841de5a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} @@ -292,14 +293,14 @@ object KeyStateEncoderSpec { // match on type m("keyStateEncoderType").asInstanceOf[String] match { case "NoPrefixKeyStateEncoderSpec" => - NoPrefixKeyStateEncoderSpec(keySchema) + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA) case "RangeKeyScanStateEncoderSpec" => val orderingOrdinals = m("orderingOrdinals"). 
asInstanceOf[List[_]].map(_.asInstanceOf[Int]) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt] - PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey.toInt) + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey.toInt) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index a4594b5abad4d..10dc7e9c4f898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -939,7 +939,7 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest { new StructType().add("key", new StructType().add("value", StringType)), new StructType().add("value", - new StructType().add("value", LongType)), + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None ) @@ -949,7 +949,7 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest { new StructType().add("value", StringType)), new StructType().add("value", new StructType() - .add("id", LongType) + .add("id", LongType, false) .add("name", StringType)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None @@ -961,12 +961,11 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest { .add("userKey", new StructType().add("value", StringType)), new StructType().add("value", new StructType() - .add("id", IntegerType) + .add("id", IntegerType, false) .add("name", StringType)), PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), Option(new StructType().add("value", StringType)) ) - println("print out schema0: " + schema0) val inputData = MemoryStream[String] val result = inputData.toDS() @@ -993,9 +992,9 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest { q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) assert(colFamilySeq.length == 3) - assert(colFamilySeq.toSet == Set( + assert(colFamilySeq.map(_.toString).toSet == Set( schema0, schema1, schema2 - )) + ).map(_.toString)) }, StopStream ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 54004b419f759..3f7a0be5759c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql.streaming import java.time.Duration +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, ValueStateImplWithTTL} -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaV3File} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ object TTLInputProcessFunction { def processRow( @@ -111,15 +115,15 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig) } } -case class MultipleValueStatesTTLProcessor( +class MultipleValueStatesTTLProcessor( ttlKey: String, noTtlKey: String, ttlConfig: TTLConfig) extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ - @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + @transient var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ + @transient var _valueStateWithoutTTL: ValueStateImpl[Int] = _ override def init( outputMode: OutputMode, @@ -160,6 +164,28 @@ case class MultipleValueStatesTTLProcessor( } } +class TTLProcessorWithCompositeTypes( + ttlKey: String, + noTtlKey: String, + ttlConfig: TTLConfig) + extends MultipleValueStatesTTLProcessor( + ttlKey: String, noTtlKey: String, ttlConfig: TTLConfig) { + @transient private var _listStateWithTTL: ListStateImplWithTTL[Int] = _ + @transient private var _mapStateWithTTL: MapStateImplWithTTL[Int, String] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + super.init(outputMode, timeMode) + _listStateWithTTL = getHandle + .getListState("listState", Encoders.scalaInt, ttlConfig) + .asInstanceOf[ListStateImplWithTTL[Int]] + _mapStateWithTTL = getHandle + .getMapState("mapState", Encoders.scalaInt, Encoders.STRING, ttlConfig) + .asInstanceOf[MapStateImplWithTTL[Int, String]] + } +} + class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { import testImplicits._ @@ -181,7 +207,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { val result = inputStream.toDS() .groupByKey(x => x.key) .transformWithState( - MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig), + new MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig), TimeMode.ProcessingTime(), OutputMode.Append()) @@ -225,4 +251,108 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { ) } } + + test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val metadataPathPostfix = "state/0/default/_metadata" + val stateSchemaPath = new Path(checkpointDir.toString, + s"$metadataPathPostfix/schema/0") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) + + val schema0 = ColumnFamilySchemaV1( + "valueState", + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false) + .add("ttlExpirationMs", LongType)), + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + None + ) + val schema1 = ColumnFamilySchemaV1( + "listState", + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType() + .add("value", IntegerType, false) + .add("ttlExpirationMs", LongType)), + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + None + ) + val schema2 = ColumnFamilySchemaV1( + "mapState", + new StructType() + .add("key", new 
StructType().add("value", StringType)) + .add("userKey", new StructType().add("value", StringType)), + new StructType().add("value", + new StructType() + .add("value", IntegerType, false) + .add("ttlExpirationMs", LongType)), + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), + Option(new StructType().add("value", StringType)) + ) + + val ttlKey = "k1" + val noTtlKey = "k2" + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new TTLProcessorWithCompositeTypes(ttlKey, noTtlKey, ttlConfig), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputStream, InputEvent(ttlKey, "put", 1)), + AddData(inputStream, InputEvent(noTtlKey, "put", 2)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + Execute { q => + println("last progress:" + q.lastProgress) + val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath + val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString, + metadataPathPostfix).toString) + val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath)) + + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics + .get("numValueStateWithTTLVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics + .get("numListStateWithTTLVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics + .get("numMapStateWithTTLVars").toInt) + + // TODO when there are two state var with the same name, + // only one schema file is preserved + assert(colFamilySeq.length == 3) + /* + assert(colFamilySeq.map(_.toString).toSet == Set( + schema0, schema1, schema2 + ).map(_.toString)) */ + + assert(colFamilySeq(1).toString == schema1.toString) + assert(colFamilySeq(2).toString == schema2.toString) + // The remaining schema file is the one without ttl + // assert(colFamilySeq.head.toString == schema0.toString) + }, + StopStream + ) + } + } + } } From 3691a16d051b9d36813278c2cecce1be9aa3e08e Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 9 Jul 2024 09:17:47 -0700 Subject: [PATCH 05/22] feedback --- .../streaming/ColumnFamilySchemaUtils.scala | 6 +- .../streaming/IncrementalExecution.scala | 4 ++ .../StatefulProcessorHandleImpl.scala | 65 ------------------- .../StreamingSymmetricHashJoinExec.scala | 3 +- .../streaming/TransformWithStateExec.scala | 17 +++-- .../streaming/state/SchemaHelper.scala | 2 +- .../StateSchemaCompatibilityChecker.scala | 1 - .../streaming/state/StateSchemaV3File.scala | 16 +++-- .../streaming/state/StateStore.scala | 2 +- .../streaming/state/StateStoreSuite.scala | 26 ++++++++ 10 files changed, 55 insertions(+), 87 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala index cd56986f23b4e..9fdefe32f1e41 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala @@ -51,7 +51,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( stateName, getKeySchema(keyEncoder.schema), getValueSchemaWithTTL(valEncoder.schema, hasTtl), @@ -63,7 +63,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( stateName, getKeySchema(keyEncoder.schema), getValueSchemaWithTTL(valEncoder.schema, hasTtl), @@ -77,7 +77,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { valEncoder: Encoder[V], hasTtl: Boolean): ColumnFamilySchemaV1 = { val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( stateName, compositeKeySchema, getValueSchemaWithTTL(valEncoder.schema, hasTtl), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 7e7e9d76081e1..c65d35bb6c3d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -85,6 +85,10 @@ class IncrementalExecution( .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** + * This value dictates which schema format version the state schema should be written in + * for all operators other than TransformWithState. + */ private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2 /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 5b6abdd5fd1e0..65b435b5c692c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -309,16 +309,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap - /** - * Function to add the ValueState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently - */ override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. 
@@ -327,17 +317,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi null.asInstanceOf[ValueState[T]] } - /** - * Function to add the ValueStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently - */ override def getValueState[T]( stateName: String, valEncoder: Encoder[T], @@ -349,16 +328,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi null.asInstanceOf[ValueState[T]] } - /** - * Function to add the ListState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. @@ -367,17 +336,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi null.asInstanceOf[ListState[T]] } - /** - * Function to add the ListStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ override def getListState[T]( stateName: String, valEncoder: Encoder[T], @@ -389,17 +347,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi null.asInstanceOf[ListState[T]] } - /** - * Function to add the MapState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - spark sql encoder for the map value - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently - */ override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], @@ -411,18 +358,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi null.asInstanceOf[MapState[K, V]] } - /** - * Function to add the MapStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. 
- * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently - */ override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 3d7e1900eebb8..ea275a28780ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -249,8 +249,7 @@ case class StreamingSymmetricHashJoinExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int - ): Array[String] = { + stateSchemaVersion: Int): Array[String] = { var result: Map[String, (StructType, StructType)] = Map.empty // get state schema for state stores on left side of the join result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 7ae8b5d1eb63e..f8cf480f830c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -98,7 +98,7 @@ case class TransformWithStateExec( * and fetch the schemas of the state variables initialized in this processor. * @return a new instance of the driver processor handle */ - private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = { + private def getDriverProcessorHandle(): DriverStatefulProcessorHandleImpl = { val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder) driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) statefulProcessor.setHandle(driverProcessorHandle) @@ -111,12 +111,16 @@ case class TransformWithStateExec( * after init is called. */ private def getColFamilySchemas(): Map[String, ColumnFamilySchema] = { - val driverProcessorHandle = getDriverProcessorHandle - val columnFamilySchemas = driverProcessorHandle.getColumnFamilySchemas + val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas closeProcessorHandle() columnFamilySchemas } + /** + * This method is used for the driver-side stateful processor after we + * have collected all the necessary schemas. + * This instance of the stateful processor won't be used again. 
+ */ private def closeProcessorHandle(): Unit = { statefulProcessor.close() statefulProcessor.setHandle(null) @@ -373,12 +377,11 @@ case class TransformWithStateExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int): - Array[String] = { + stateSchemaVersion: Int): Array[String] = { assert(stateSchemaVersion >= 3) val newColumnFamilySchemas = getColFamilySchemas() val schemaFile = new StateSchemaV3File( - hadoopConf, stateSchemaFilePath(StateStoreId.DEFAULT_STORE_NAME).toString) + hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString) // TODO: Read the schema path from the OperatorStateMetadata file // and validate it with the new schema @@ -402,7 +405,7 @@ case class TransformWithStateExec( } } - private def stateSchemaFilePath(storeName: String): Path = { + private def stateSchemaDirPath(storeName: String): Path = { assert(storeName == StateStoreId.DEFAULT_STORE_NAME) def stateInfo = getStateInfo val stateCheckpointPath = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 40c2b403b04e5..0a8021ab3de2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -73,7 +73,7 @@ object ColumnFamilySchemaV1 { s"Expected Map but got ${colFamilyMap.getClass}") val keySchema = StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]) val valueSchema = StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]) - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( colFamilyMap("columnFamilyName").asInstanceOf[String], keySchema, valueSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index f7c58a4d4b752..8aabc0846fe61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -113,7 +113,6 @@ class StateSchemaCompatibilityChecker( object StateSchemaCompatibilityChecker extends Logging { val VERSION = 2 - /** * Function to check if new state store schema is compatible with the existing schema. * @param oldSchema - old state schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala index 5cee6eb807c4f..07a94400f30f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers * The StateSchemaV3File is used to write the schema of multiple column families. * Right now, this is primarily used for the TransformWithState operator, which supports * multiple column families to keep the data for multiple state variables. + * We only expect ColumnFamilySchemaV1 to be written and read from this file. 
* @param hadoopConf Hadoop configuration that is used to read / write metadata files. * @param path Path to the directory that will be used for writing metadata. */ @@ -40,8 +41,6 @@ class StateSchemaV3File( hadoopConf: Configuration, path: String) { - val VERSION = 3 - val metadataPath = new Path(path) protected val fileManager: CheckpointFileManager = @@ -51,7 +50,7 @@ class StateSchemaV3File( fileManager.mkdirs(metadataPath) } - def deserialize(in: InputStream): List[ColumnFamilySchema] = { + private def deserialize(in: InputStream): List[ColumnFamilySchema] = { val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { @@ -59,13 +58,13 @@ class StateSchemaV3File( } val version = lines.next().trim - validateVersion(version, VERSION) + validateVersion(version, StateSchemaV3File.VERSION) lines.map(ColumnFamilySchemaV1.fromJson).toList } - def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { - out.write(s"v${VERSION}".getBytes(UTF_8)) + private def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { + out.write(s"v${StateSchemaV3File.VERSION}".getBytes(UTF_8)) out.write('\n') out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8)) } @@ -85,7 +84,6 @@ class StateSchemaV3File( protected def write( batchMetadataFile: Path, fn: OutputStream => Unit): Unit = { - // Only write metadata when the batch has not yet been written val output = fileManager.createAtomic(batchMetadataFile, overwriteIfPossible = false) try { fn(output) @@ -101,3 +99,7 @@ class StateSchemaV3File( new Path(metadataPath, batchId.toString) } } + +object StateSchemaV3File { + val VERSION = 3 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a244e841de5a2..387d0aeaab24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -296,7 +296,7 @@ object KeyStateEncoderSpec { NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA) case "RangeKeyScanStateEncoderSpec" => val orderingOrdinals = m("orderingOrdinals"). 
- asInstanceOf[List[_]].map(_.asInstanceOf[Int]) + asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 98b2030f1bac4..2c4111ec026ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -30,6 +30,8 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -1627,6 +1629,30 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] keyRow, keySchema, valueRow, keySchema, storeConf) } + test("test serialization and deserialization of NoPrefixKeyStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + + test("test serialization and deserialization of PrefixKeyScanStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + + test("test serialization and deserialization of RangeKeyScanStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = RangeKeyScanStateEncoderSpec(keySchema, Seq(1)) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + /** Return a new provider with a random id */ def newStoreProvider(): ProviderClass From 0ad367954551194df645df5034994c3b4a3b4139 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 26 Jun 2024 17:00:35 -0700 Subject: [PATCH 06/22] creating operatorstatemetadata log --- .../state/metadata/StateMetadataSource.scala | 34 ++++++--- .../streaming/IncrementalExecution.scala | 12 ++-- .../streaming/OperatorStateMetadataLog.scala | 69 +++++++++++++++++++ .../StreamingSymmetricHashJoinExec.scala | 3 +- .../streaming/TransformWithStateExec.scala | 24 +++++++ .../state/OperatorStateMetadata.scala | 55 ++++++++++++++- .../streaming/statefulOperators.scala | 13 +++- .../state/OperatorStateMetadataSuite.scala | 2 +- 8 files changed, 190 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 28de21aaf9389..19f504b8e8daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH -import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,6 +46,7 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, + operatorPropertiesJson: String, numColsPrefixKey: Int) { def toRow(): InternalRow = { new GenericInternalRow( @@ -55,6 +56,7 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, + UTF8String.fromString(operatorPropertiesJson), numColsPrefixKey)) } } @@ -68,6 +70,7 @@ object StateMetadataTableEntry { .add("numPartitions", IntegerType) .add("minBatchId", LongType) .add("maxBatchId", LongType) + .add("operatorProperties", StringType) } } @@ -192,22 +195,35 @@ class StateMetadataPartitionReader( val stateDir = new Path(checkpointLocation, "state") val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted - opIds.map { opId => - new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + opIds.flatMap { opId => + val operatorIdPath = new Path(stateDir, opId.toString) + // check all OperatorStateMetadataV2 + val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataFilePath(operatorIdPath) + if (fileManager.exists(operatorStateMetadataV2Path)) { + val operatorStateMetadataLog = new OperatorStateMetadataLog( + hadoopConf, operatorStateMetadataV2Path.toString) + operatorStateMetadataLog.listBatchesOnDisk.flatMap(operatorStateMetadataLog.get) + } else { + Array(new OperatorStateMetadataReader(operatorIdPath, hadoopConf).read()) + } } } private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { allOperatorStateMetadata.flatMap { operatorStateMetadata => - require(operatorStateMetadata.version == 1) - val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] - operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, - operatorStateMetadataV1.operatorInfo.operatorName, + require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) + val operatorProperties = operatorStateMetadata match { + case _: OperatorStateMetadataV1 => "" + 
case v2: OperatorStateMetadataV2 => v2.operatorPropertiesJson + } + operatorStateMetadata.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(operatorStateMetadata.operatorInfo.operatorId, + operatorStateMetadata.operatorInfo.operatorName, stateStoreMetadata.storeName, stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, + operatorProperties, stateStoreMetadata.numColsPrefixKey ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c65d35bb6c3d7..14af35d4da265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataWriter} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -456,11 +456,11 @@ class IncrementalExecution( new Path(checkpointLocation).getParent.toString, new SerializableConfiguration(hadoopConf)) val opMetadataList = reader.allOperatorStateMetadata - ret = opMetadataList.map { operatorMetadata => - val metadataInfoV1 = operatorMetadata - .asInstanceOf[OperatorStateMetadataV1] - .operatorInfo - metadataInfoV1.operatorId -> metadataInfoV1.operatorName + ret = opMetadataList.map { + case OperatorStateMetadataV1(operatorInfo, _) => + operatorInfo.operatorId -> operatorInfo.operatorName + case OperatorStateMetadataV2(operatorInfo, _, _) => + operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala new file mode 100644 index 0000000000000..80e334be1101f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} +import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataOutputStream + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} +import org.apache.spark.sql.internal.SQLConf + + +class OperatorStateMetadataLog( + hadoopConf: Configuration, + path: String, + metadataCacheEnabled: Boolean = false) + extends HDFSMetadataLog[OperatorStateMetadata](hadoopConf, path, metadataCacheEnabled) { + + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + + override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = { + val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream] + fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8)) + metadata.version match { + case 1 => + OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata) + case 2 => + OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata) + } + } + + override protected def deserialize(in: InputStream): OperatorStateMetadata = { + // called inside a try-finally where the underlying stream is closed in the caller + // create buffered reader from input stream + val bufferedReader = new BufferedReader(new InputStreamReader(in, UTF_8)) + // read first line for version number, in the format "v{version}" + val version = bufferedReader.readLine() + version match { + case "v1" => OperatorStateMetadataV1.deserialize(bufferedReader) + case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index ea275a28780ef..153c5b4980b35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -230,7 +230,8 @@ case class StreamingSymmetricHashJoinExec( override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray + val stateStoreInfo: Array[StateStoreMetadata] = + stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index f8cf480f830c1..f6efa112a4257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration import 
org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.JString +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -93,6 +97,8 @@ case class TransformWithStateExec( } } + override def operatorStateMetadataVersion: Int = 2 + /** * We initialize this processor handle in the driver to run the init function * and fetch the schemas of the state variables initialized in this processor. @@ -416,6 +422,24 @@ case class TransformWithStateExec( new Path(new Path(storeNamePath, "_metadata"), "schema") } + /** Metadata of this stateful operator and its states stores. */ + override def operatorStateMetadata(): OperatorStateMetadata = { + val info = getStateInfo + val operatorInfo = OperatorInfoV1(info.operatorId, shortName) + // stateSchemaFilePath should be populated at this point + assert(info.stateSchemaPath.isDefined) + val stateStoreInfo: Array[StateStoreMetadata] = + Array(StateStoreMetadataV2( + StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, info.stateSchemaPath.get)) + + val operatorPropertiesJson: JValue = + ("timeMode" -> JString(timeMode.toString)) ~ + ("outputMode" -> JString(outputMode.toString)) + + val json = compact(render(operatorPropertiesJson)) + OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 8ce883038401d..66832167b35d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, Metadata /** * Metadata for a state store instance. */ -trait StateStoreMetadata { +trait StateStoreMetadata extends Serializable { def storeName: String def numColsPrefixKey: Int def numPartitions: Int @@ -42,6 +42,21 @@ trait StateStoreMetadata { case class StateStoreMetadataV1(storeName: String, numColsPrefixKey: Int, numPartitions: Int) extends StateStoreMetadata +case class StateStoreMetadataV2( + storeName: String, + numColsPrefixKey: Int, + numPartitions: Int, + stateSchemaFilePath: String) + extends StateStoreMetadata + +object StateStoreMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[StateStoreMetadataV2](implicitly[ClassTag[StateStoreMetadataV2]].runtimeClass) +} + /** * Information about a stateful operator. 
*/ @@ -54,14 +69,25 @@ case class OperatorInfoV1(operatorId: Long, operatorName: String) extends Operat trait OperatorStateMetadata { def version: Int + + def operatorInfo: OperatorInfo + + def stateStoreInfo: Array[StateStoreMetadata] } case class OperatorStateMetadataV1( operatorInfo: OperatorInfoV1, - stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata { + stateStoreInfo: Array[StateStoreMetadata]) extends OperatorStateMetadata { override def version: Int = 1 } +case class OperatorStateMetadataV2( + operatorInfo: OperatorInfoV1, + stateStoreInfo: Array[StateStoreMetadata], + operatorPropertiesJson: String) extends OperatorStateMetadata { + override def version: Int = 2 +} + object OperatorStateMetadataV1 { private implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -84,6 +110,27 @@ object OperatorStateMetadataV1 { } } +object OperatorStateMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) + + def metadataFilePath(stateCheckpointPath: Path): Path = + new Path(new Path(new Path(stateCheckpointPath, "v2"), "_metadata"), "metadata") + + def deserialize(in: BufferedReader): OperatorStateMetadata = { + Serialization.read[OperatorStateMetadataV2](in) + } + + def serialize( + out: FSDataOutputStream, + operatorStateMetadata: OperatorStateMetadata): Unit = { + Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out) + } +} + /** * Write OperatorStateMetadata into the state checkpoint directory. */ @@ -114,7 +161,9 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu } /** - * Read OperatorStateMetadata from the state checkpoint directory. + * Read OperatorStateMetadata from the state checkpoint directory. This class will only be + * used to read OperatorStateMetadataV1. + * OperatorStateMetadataV2 will be read by the OperatorStateMetadataLog. */ class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 94d976b568a5e..c175ec2534b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -73,6 +74,12 @@ trait StatefulOperator extends SparkPlan { } } + def metadataFilePath(): Path = { + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + } + // Function used to record state schema for the first time and validate it against proposed // schema changes in the future. Runs as part of a planning rule on the driver. 
// Returns the schema file path for operators that write this to the metadata file, @@ -141,6 +148,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp */ def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + def operatorStateMetadataVersion: Int = 1 + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, @@ -192,7 +201,7 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = + val stateStoreInfo: Array[StateStoreMetadata] = Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions)) OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } @@ -913,7 +922,7 @@ case class SessionWindowStateStoreSaveExec( override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = Array(StateStoreMetadataV1( + val stateStoreInfo: Array[StateStoreMetadata] = Array(StateStoreMetadataV1( StateStoreId.DEFAULT_STORE_NAME, stateManager.getNumColsForPrefixKey, info.numPartitions)) OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index dd8f7aab51dd0..b55bf63054ab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -106,7 +106,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { StopStream ) - val expectedStateStoreInfo = Array( + val expectedStateStoreInfo: Array[StateStoreMetadata] = Array( StateStoreMetadataV1("left-keyToNumValues", 0, numShufflePartitions), StateStoreMetadataV1("left-keyWithIndexToValue", 0, numShufflePartitions), StateStoreMetadataV1("right-keyToNumValues", 0, numShufflePartitions), From ef86e37d1f45573ad2d32f8a11f31b49c1124417 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 27 Jun 2024 15:14:16 -0700 Subject: [PATCH 07/22] removing ': Array[StateStoreMetadata]' --- .../state/metadata/StateMetadataSource.scala | 39 ++++++++++++------- .../StreamingSymmetricHashJoinExec.scala | 2 +- .../streaming/TransformWithStateExec.scala | 2 +- .../state/OperatorStateMetadata.scala | 6 +-- .../streaming/statefulOperators.scala | 4 +- .../state/OperatorStateMetadataSuite.scala | 2 +- 6 files changed, 32 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 19f504b8e8daa..271f29ee07fce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -212,20 +212,31 @@ class StateMetadataPartitionReader( private[state] lazy val 
stateMetadata: Iterator[StateMetadataTableEntry] = { allOperatorStateMetadata.flatMap { operatorStateMetadata => require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) - val operatorProperties = operatorStateMetadata match { - case _: OperatorStateMetadataV1 => "" - case v2: OperatorStateMetadataV2 => v2.operatorPropertiesJson - } - operatorStateMetadata.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadata.operatorInfo.operatorId, - operatorStateMetadata.operatorInfo.operatorName, - stateStoreMetadata.storeName, - stateStoreMetadata.numPartitions, - if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1, - operatorProperties, - stateStoreMetadata.numColsPrefixKey - ) + operatorStateMetadata match { + case v1: OperatorStateMetadataV1 => + v1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v1.operatorInfo.operatorId, + v1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + "", + stateStoreMetadata.numColsPrefixKey + ) + } + case v2: OperatorStateMetadataV2 => + v2.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v2.operatorInfo.operatorId, + v2.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + v2.operatorPropertiesJson, + stateStoreMetadata.numColsPrefixKey + ) + } } } }.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 153c5b4980b35..1e0df7eae6bb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -230,7 +230,7 @@ case class StreamingSymmetricHashJoinExec( override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo: Array[StateStoreMetadata] = + val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index f6efa112a4257..3ff124572a52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -428,7 +428,7 @@ case class TransformWithStateExec( val operatorInfo = OperatorInfoV1(info.operatorId, shortName) // stateSchemaFilePath should be populated at this point assert(info.stateSchemaPath.isDefined) - val stateStoreInfo: Array[StateStoreMetadata] = + val stateStoreInfo = Array(StateStoreMetadataV2( StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, info.stateSchemaPath.get)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 66832167b35d9..f68c9db1662c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -71,19 +71,17 @@ trait OperatorStateMetadata { def version: Int def operatorInfo: OperatorInfo - - def stateStoreInfo: Array[StateStoreMetadata] } case class OperatorStateMetadataV1( operatorInfo: OperatorInfoV1, - stateStoreInfo: Array[StateStoreMetadata]) extends OperatorStateMetadata { + stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata { override def version: Int = 1 } case class OperatorStateMetadataV2( operatorInfo: OperatorInfoV1, - stateStoreInfo: Array[StateStoreMetadata], + stateStoreInfo: Array[StateStoreMetadataV2], operatorPropertiesJson: String) extends OperatorStateMetadata { override def version: Int = 2 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c175ec2534b3b..37563579a5e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -201,7 +201,7 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo: Array[StateStoreMetadata] = + val stateStoreInfo = Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions)) OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } @@ -922,7 +922,7 @@ case class SessionWindowStateStoreSaveExec( override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo: Array[StateStoreMetadata] = Array(StateStoreMetadataV1( + val stateStoreInfo = Array(StateStoreMetadataV1( StateStoreId.DEFAULT_STORE_NAME, stateManager.getNumColsForPrefixKey, info.numPartitions)) OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index b55bf63054ab8..dd8f7aab51dd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -106,7 +106,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { StopStream ) - val expectedStateStoreInfo: Array[StateStoreMetadata] = Array( + val expectedStateStoreInfo = Array( StateStoreMetadataV1("left-keyToNumValues", 0, numShufflePartitions), StateStoreMetadataV1("left-keyWithIndexToValue", 0, numShufflePartitions), StateStoreMetadataV1("right-keyToNumValues", 0, numShufflePartitions), From 06940c30917ac562193acd03f91638719dd4b2e5 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 27 Jun 2024 15:30:34 -0700 Subject: [PATCH 08/22] adding operatorProperties as a metadata column --- .../state/metadata/StateMetadataSource.scala | 26 
++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 271f29ee07fce..2f27c4ebf4277 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -46,8 +46,8 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, - operatorPropertiesJson: String, - numColsPrefixKey: Int) { + numColsPrefixKey: Int, + operatorPropertiesJson: String) { def toRow(): InternalRow = { new GenericInternalRow( Array[Any](operatorId, @@ -56,8 +56,8 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, - UTF8String.fromString(operatorPropertiesJson), - numColsPrefixKey)) + numColsPrefixKey, + UTF8String.fromString(operatorPropertiesJson))) } } @@ -70,7 +70,6 @@ object StateMetadataTableEntry { .add("numPartitions", IntegerType) .add("minBatchId", LongType) .add("maxBatchId", LongType) - .add("operatorProperties", StringType) } } @@ -113,7 +112,14 @@ class StateMetadataTable extends Table with SupportsRead with SupportsMetadataCo override def comment: String = "Number of columns in prefix key of the state store instance" } - override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn) + private object OperatorPropertiesColumn extends MetadataColumn { + override def name: String = "_operatorProperties" + override def dataType: DataType = StringType + override def comment: String = "Json string storing operator properties" + } + + override val metadataColumns: Array[MetadataColumn] = + Array(NumColsPrefixKeyColumn, OperatorPropertiesColumn) } case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition @@ -221,8 +227,8 @@ class StateMetadataPartitionReader( stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, - "", - stateStoreMetadata.numColsPrefixKey + stateStoreMetadata.numColsPrefixKey, + "" ) } case v2: OperatorStateMetadataV2 => @@ -233,8 +239,8 @@ class StateMetadataPartitionReader( stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, - v2.operatorPropertiesJson, - stateStoreMetadata.numColsPrefixKey + stateStoreMetadata.numColsPrefixKey, + v2.operatorPropertiesJson ) } } From a668b777835c4cacf8bb9cc91cc92a0e2d7739d1 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 27 Jun 2024 16:15:17 -0700 Subject: [PATCH 09/22] changing the order of the metadata --- .../sql/execution/streaming/state/OperatorStateMetadata.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index f68c9db1662c8..47e5166c91188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -116,7 +116,7 @@ object OperatorStateMetadataV2 { 
.classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) def metadataFilePath(stateCheckpointPath: Path): Path = - new Path(new Path(new Path(stateCheckpointPath, "v2"), "_metadata"), "metadata") + new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2") def deserialize(in: BufferedReader): OperatorStateMetadata = { Serialization.read[OperatorStateMetadataV2](in) From a057166853c673a08ac09a4d774ca1fdd3a91bf1 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 28 Jun 2024 09:46:14 -0700 Subject: [PATCH 10/22] tests pass --- .../v2/state/metadata/StateMetadataSource.scala | 17 ++++++++++------- .../streaming/IncrementalExecution.scala | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 2f27c4ebf4277..50abb6d18b987 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -197,26 +197,29 @@ class StateMetadataPartitionReader( } else Array.empty } - private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + private[sql] def allOperatorStateMetadata: Array[(OperatorStateMetadata, Long)] = { val stateDir = new Path(checkpointLocation, "state") val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted opIds.flatMap { opId => val operatorIdPath = new Path(stateDir, opId.toString) - // check all OperatorStateMetadataV2 + // check if OperatorStateMetadataV2 path exists, if it does, read it + // otherwise, fall back to OperatorStateMetadataV1 val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataFilePath(operatorIdPath) if (fileManager.exists(operatorStateMetadataV2Path)) { val operatorStateMetadataLog = new OperatorStateMetadataLog( hadoopConf, operatorStateMetadataV2Path.toString) - operatorStateMetadataLog.listBatchesOnDisk.flatMap(operatorStateMetadataLog.get) + operatorStateMetadataLog.listBatchesOnDisk.flatMap { batchId => + operatorStateMetadataLog.get(batchId).map((_, batchId)) + } } else { - Array(new OperatorStateMetadataReader(operatorIdPath, hadoopConf).read()) + Array((new OperatorStateMetadataReader(operatorIdPath, hadoopConf).read(), -1L)) } } } private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { - allOperatorStateMetadata.flatMap { operatorStateMetadata => + allOperatorStateMetadata.flatMap { case (operatorStateMetadata, batchId) => require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) operatorStateMetadata match { case v1: OperatorStateMetadataV1 => @@ -237,8 +240,8 @@ class StateMetadataPartitionReader( v2.operatorInfo.operatorName, stateStoreMetadata.storeName, stateStoreMetadata.numPartitions, - if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1, + batchId, + batchId, stateStoreMetadata.numColsPrefixKey, v2.operatorPropertiesJson ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 14af35d4da265..091f9607dacba 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -457,9 +457,9 @@ class IncrementalExecution( new SerializableConfiguration(hadoopConf)) val opMetadataList = reader.allOperatorStateMetadata ret = opMetadataList.map { - case OperatorStateMetadataV1(operatorInfo, _) => + case (OperatorStateMetadataV1(operatorInfo, _), _) => operatorInfo.operatorId -> operatorInfo.operatorName - case OperatorStateMetadataV2(operatorInfo, _, _) => + case (OperatorStateMetadataV2(operatorInfo, _, _), _) => operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { From 07ccd55407896f88fb19efd5635e4af73aec9a2d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 28 Jun 2024 10:36:15 -0700 Subject: [PATCH 11/22] test case --- .../streaming/TransformWithStateSuite.scala | 179 +++++++++++++++++- 1 file changed, 178 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 10dc7e9c4f898..1bfa9211ee950 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} @@ -804,6 +804,183 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + private def fetchColumnFamilySchemas( + checkpointDir: String, + operatorId: Int): List[ColumnFamilySchema] = { + fetchStateSchemaV3File(checkpointDir, operatorId).getLatest().get._2 + } + + private def fetchStateSchemaV3File( + checkpointDir: String, + operatorId: Int): StateSchemaV3File = { + val hadoopConf = spark.sessionState.newHadoopConf() + val stateChkptPath = new Path(checkpointDir, s"state/$operatorId") + val stateSchemaPath = new Path(new Path(stateChkptPath, "_metadata"), "schema") + new StateSchemaV3File(hadoopConf, stateSchemaPath.toString) + } + + test("transformWithState - verify StateSchemaV3 file is written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { chkptDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + + val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) + assert(columnFamilySchemas.length == 1) + + val expected = ColumnFamilySchemaV1( + "countState", + KEY_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false, Encoders.scalaLong.schema, None + ) + val actual = 
columnFamilySchemas.head.asInstanceOf[ColumnFamilySchemaV1] + assert(expected == actual) + } + } + } + + test("transformWithState - verify StateSchemaV3 file is written correctly," + + " multiple column families") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + StopStream + ) + + val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) + assert(columnFamilySchemas.length == 2) + + val expected = List( + ColumnFamilySchemaV1( + "countState", + KEY_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false, + Encoders.scalaLong.schema, + None + ), + ColumnFamilySchemaV1( + "mostRecent", + KEY_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false, + Encoders.STRING.schema, + None + ) + ) + val actual = columnFamilySchemas.map(_.asInstanceOf[ColumnFamilySchemaV1]) + assert(expected == actual) + } + } + } + + test("transformWithState - verify that StateSchemaV3 files are purged") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + withTempDir { chkptDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream, + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream, + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(), + StopStream + ) + // If the StateSchemaV3 files are not purged, there would be + // three files, but we should have only one file. 
+ val batchesWithSchemaV3File = fetchStateSchemaV3File( + chkptDir.getCanonicalPath, 0).listBatchesOnDisk + assert(batchesWithSchemaV3File.length == 1) + // Make sure that only the latest batch has the schema file + assert(batchesWithSchemaV3File.head == 2) + } + } + } + + test("transformWithState - verify that OperatorStateMetadataV2" + + " file is being written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream, + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) + + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + checkAnswer(df, Seq( + Row(0, "transformWithStateExec", "default", 5, 0L, 0L), + Row(0, "transformWithStateExec", "default", 5, 1L, 1L) + )) + checkAnswer(df.select(df.metadataColumn("_operatorProperties")), + Seq( + Row("""{"timeMode":"NoTime","outputMode":"Update"}"""), + Row("""{"timeMode":"NoTime","outputMode":"Update"}""") + ) + ) + } + } + } + test("transformWithState - verify StateSchemaV3 serialization and deserialization" + " works with one batch") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> From 03265af43b1ba9b1566d7641cb829161d0d9b29d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 1 Jul 2024 10:56:15 -0700 Subject: [PATCH 12/22] rebase --- .../streaming/StreamingSymmetricHashJoinExec.scala | 3 ++- .../sql/execution/streaming/TransformWithStateExec.scala | 6 +++--- .../spark/sql/execution/streaming/statefulOperators.scala | 6 ++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 1e0df7eae6bb9..b3b8f66cd547e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -227,7 +227,8 @@ case class StreamingSymmetricHashJoinExec( private val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 3ff124572a52a..4a76acb21c78d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -423,14 +423,14 @@ case class 
TransformWithStateExec( } /** Metadata of this stateful operator and its states stores. */ - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) // stateSchemaFilePath should be populated at this point - assert(info.stateSchemaPath.isDefined) val stateStoreInfo = Array(StateStoreMetadataV2( - StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, info.stateSchemaPath.get)) + StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head)) val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 37563579a5e75..437c57de8cfdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -198,7 +198,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** Metadata of this stateful operator and its states stores. */ - def operatorStateMetadata(): OperatorStateMetadata = { + def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -919,7 +920,8 @@ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = Array(StateStoreMetadataV1( From 22b8b0a1b18db9289adf50e02726cb52d3114120 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 1 Jul 2024 17:09:54 -0700 Subject: [PATCH 13/22] Feedback --- .../spark/sql/execution/streaming/TransformWithStateExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 4a76acb21c78d..f03261b18e3e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -434,7 +434,7 @@ case class TransformWithStateExec( val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ - ("outputMode" -> JString(outputMode.toString)) + ("outputMode" -> JString(outputMode.toString)) val json = compact(render(operatorPropertiesJson)) OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) From 37392bf9233c7fc44d3a4952ece39ae5eb69237b Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 2 Jul 2024 10:40:49 -0700 Subject: [PATCH 14/22] files written correctly --- .../execution/streaming/StreamExecution.scala | 18 ++++++++++++++++++ .../streaming/TransformWithStateExec.scala | 14 ++++++++++++++ .../streaming/TransformWithStateSuite.scala | 
16 +++++++++++++++- 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 81f7acdb755bc..f9e38898f0e40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException import org.apache.spark.sql.internal.SQLConf @@ -239,6 +240,23 @@ abstract class StreamExecution( */ val commitLog = new CommitLog(sparkSession, checkpointFile("commits")) + + lazy val operatorStateMetadataLogs: Map[Long, OperatorStateMetadataLog] = { + populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) + } + + private def populateOperatorStateMetadatas( + plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = { + plan.flatMap { + case s: StateStoreWriter => s.stateInfo.map { info => + val metadataPath = s.metadataFilePath() + info.operatorId -> new OperatorStateMetadataLog(sparkSession, + metadataPath.toString) + } + case _ => Seq.empty + }.toMap + } + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index f03261b18e3e5..34194e9bd52c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -440,6 +440,20 @@ case class TransformWithStateExec( OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } + private def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 1bfa9211ee950..f2753b52166f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -804,10 +804,24 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + private def fetchOperatorStateMetadataLog( + 
checkpointDir: String, + operatorId: Int): OperatorStateMetadataLog = { + val hadoopConf = spark.sessionState.newHadoopConf() + val stateChkptPath = new Path(checkpointDir, s"state/$operatorId") + val operatorStateMetadataPath = OperatorStateMetadataV2.metadataFilePath(stateChkptPath) + new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString) + } + private def fetchColumnFamilySchemas( checkpointDir: String, operatorId: Int): List[ColumnFamilySchema] = { - fetchStateSchemaV3File(checkpointDir, operatorId).getLatest().get._2 + val operatorStateMetadataLog = fetchOperatorStateMetadataLog(checkpointDir, operatorId) + val stateSchemaFilePath = operatorStateMetadataLog. + getLatest().get._2. + asInstanceOf[OperatorStateMetadataV2]. + stateStoreInfo.head.stateSchemaFilePath + fetchStateSchemaV3File(checkpointDir, operatorId).get(new Path(stateSchemaFilePath)) } private def fetchStateSchemaV3File( From cbbd47f01168fd789e02fda718f3c310074beb1d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 2 Jul 2024 14:59:37 -0700 Subject: [PATCH 15/22] tests minus purging --- .../execution/streaming/StreamExecution.scala | 18 ------------------ .../streaming/TransformWithStateExec.scala | 14 -------------- .../streaming/statefulOperators.scala | 14 ++++++++++++++ 3 files changed, 14 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index f9e38898f0e40..81f7acdb755bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException import org.apache.spark.sql.internal.SQLConf @@ -240,23 +239,6 @@ abstract class StreamExecution( */ val commitLog = new CommitLog(sparkSession, checkpointFile("commits")) - - lazy val operatorStateMetadataLogs: Map[Long, OperatorStateMetadataLog] = { - populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) - } - - private def populateOperatorStateMetadatas( - plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = { - plan.flatMap { - case s: StateStoreWriter => s.stateInfo.map { info => - val metadataPath = s.metadataFilePath() - info.operatorId -> new OperatorStateMetadataLog(sparkSession, - metadataPath.toString) - } - case _ => Seq.empty - }.toMap - } - /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 34194e9bd52c6..f03261b18e3e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -440,20 +440,6 @@ case class TransformWithStateExec( OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } - private def stateSchemaFilePath(storeName: Option[String] = None): Path = { - def stateInfo = getStateInfo - val stateCheckpointPath = - new Path(getStateInfo.checkpointLocation, - s"${stateInfo.operatorId.toString}") - storeName match { - case Some(storeName) => - val storeNamePath = new Path(stateCheckpointPath, storeName) - new Path(new Path(storeNamePath, "_metadata"), "schema") - case None => - new Path(new Path(stateCheckpointPath, "_metadata"), "schema") - } - } - override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 437c57de8cfdc..492ba37865d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -165,6 +165,20 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } + /** * Get the progress made by this stateful operator after execution. This should be called in * the driver after this SparkPlan has been executed and metrics have been updated. 
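[Editor's note: the following is an illustrative sketch, not part of the patch series.]

The pieces introduced in PATCH 06-15 fit together around a fixed checkpoint layout: the stateSchemaFilePath helper above resolves to <checkpoint>/state/<operatorId>/_metadata/schema for the default store (or <checkpoint>/state/<operatorId>/<storeName>/_metadata/schema for a named store), OperatorStateMetadataV2.metadataFilePath resolves to <checkpoint>/state/<operatorId>/_metadata/metadata/v2, and the next patch wires IncrementalExecution to append one OperatorStateMetadataV2 entry per batch to an OperatorStateMetadataLog at that path. A minimal round trip using only the APIs shown in the patches above, with placeholder values for the checkpoint directory, operator id, partition count, schema path and batch id (real values come from the operator's getStateInfo):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
    import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadataV2, StateStoreMetadataV2}

    // Placeholder checkpoint layout, for illustration only.
    val hadoopConf = new Configuration()
    val operatorIdPath = new Path("/tmp/ckpt/state/0")
    val metadataPath = OperatorStateMetadataV2.metadataFilePath(operatorIdPath)  // .../_metadata/metadata/v2
    val metadataLog = new OperatorStateMetadataLog(hadoopConf, metadataPath.toString)

    val metadata = OperatorStateMetadataV2(
      OperatorInfoV1(0, "transformWithStateExec"),
      Array(StateStoreMetadataV2("default", 0, 5, "/tmp/ckpt/state/0/_metadata/schema")),
      """{"timeMode":"NoTime","outputMode":"Update"}""")

    metadataLog.add(0, metadata)       // one entry per committed batch id
    val readBack = metadataLog.get(0)  // Some(OperatorStateMetadataV2(...)) after a successful write
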
From 81e1fb1f0f846559bda119c2c16284673765579f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 2 Jul 2024 16:16:56 -0700 Subject: [PATCH 16/22] tests pass --- .../streaming/IncrementalExecution.scala | 18 ++++-- .../streaming/TransformWithStateSuite.scala | 58 ++++--------------- 2 files changed, 24 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 091f9607dacba..a0ace912faf7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -211,10 +211,20 @@ class IncrementalExecution( statefulOp match { case stateStoreWriter: StateStoreWriter => val metadata = stateStoreWriter.operatorStateMetadata() - // TODO: Populate metadata with stateSchemaPaths if metadata version is v2 - val metadataWriter = new OperatorStateMetadataWriter(new Path( - checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf) - metadataWriter.write(metadata) + stateStoreWriter match { + case tws: TransformWithStateExec => + val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path( + checkpointLocation, tws.getStateInfo.operatorId.toString)) + val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession, + metadataPath.toString) + operatorStateMetadataLog.add(currentBatchId, metadata) + case _ => + val metadataWriter = new OperatorStateMetadataWriter(new Path( + checkpointLocation, + stateStoreWriter.getStateInfo.operatorId.toString), + hadoopConf) + metadataWriter.write(metadata) + } case _ => } statefulOp diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index f2753b52166f1..d44cd39a2b48c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -821,7 +821,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest 
getLatest().get._2. asInstanceOf[OperatorStateMetadataV2]. stateStoreInfo.head.stateSchemaFilePath - fetchStateSchemaV3File(checkpointDir, operatorId).get(new Path(stateSchemaFilePath)) + fetchStateSchemaV3File(checkpointDir, operatorId).getWithPath(new Path(stateSchemaFilePath)) } private def fetchStateSchemaV3File( @@ -859,8 +859,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest val expected = ColumnFamilySchemaV1( "countState", KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), - false, Encoders.scalaLong.schema, None + None ) val actual = columnFamilySchemas.head.asInstanceOf[ColumnFamilySchemaV1] assert(expected == actual) @@ -902,12 +903,12 @@ class TransformWithStateSuite extends StateStoreMetricsTest None ), ColumnFamilySchemaV1( - "mostRecent", - KEY_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), - false, - Encoders.STRING.schema, - None + "mostRecent", + KEY_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false, + Encoders.STRING.schema, + None ) ) val actual = columnFamilySchemas.map(_.asInstanceOf[ColumnFamilySchemaV1]) @@ -916,45 +917,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify that StateSchemaV3 files are purged") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { - withTempDir { chkptDir => - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - StopStream, - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), - StopStream, - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(), - StopStream - ) - // If the StateSchemaV3 files are not purged, there would be - // three files, but we should have only one file. 
- val batchesWithSchemaV3File = fetchStateSchemaV3File( - chkptDir.getCanonicalPath, 0).listBatchesOnDisk - assert(batchesWithSchemaV3File.length == 1) - // Make sure that only the latest batch has the schema file - assert(batchesWithSchemaV3File.head == 2) - } - } - } - test("transformWithState - verify that OperatorStateMetadataV2" + " file is being written correctly") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> From 99609ee6d561009835e861d33b8828bca0f5d712 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 3 Jul 2024 09:32:41 -0700 Subject: [PATCH 17/22] tests pass --- .../state/OperatorStateMetadata.scala | 6 ++--- .../streaming/TransformWithStateSuite.scala | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 47e5166c91188..a319f93fe82b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -47,7 +47,7 @@ case class StateStoreMetadataV2( numColsPrefixKey: Int, numPartitions: Int, stateSchemaFilePath: String) - extends StateStoreMetadata + extends StateStoreMetadata with Serializable object StateStoreMetadataV2 { private implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -60,14 +60,14 @@ object StateStoreMetadataV2 { /** * Information about a stateful operator. */ -trait OperatorInfo { +trait OperatorInfo extends Serializable { def operatorId: Long def operatorName: String } case class OperatorInfoV1(operatorId: Long, operatorName: String) extends OperatorInfo -trait OperatorStateMetadata { +trait OperatorStateMetadata extends Serializable { def version: Int def operatorInfo: OperatorInfo diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d44cd39a2b48c..3ef407667e5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -957,6 +957,31 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - verify OperatorStateMetadataV2 serialization and deserialization" + + " works") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val metadata = OperatorStateMetadataV2( + OperatorInfoV1(0, "transformWithStateExec"), + Array(StateStoreMetadataV2("default", 0, 0, "path")), + """{"timeMode":"NoTime","outputMode":"Update"}""") + val metadataLog = new OperatorStateMetadataLog( + spark.sessionState.newHadoopConf(), + checkpointDir.getCanonicalPath) + metadataLog.add(0, metadata) + assert(metadataLog.get(0).isDefined) + // assert that each of the fields are the same + val metadataV2 = metadataLog.get(0).get.asInstanceOf[OperatorStateMetadataV2] + assert(metadataV2.operatorInfo == metadata.operatorInfo) + assert(metadataV2.stateStoreInfo.sameElements(metadata.stateStoreInfo)) + assert(metadataV2.operatorPropertiesJson == metadata.operatorPropertiesJson) + } + } + } + 
test("transformWithState - verify StateSchemaV3 serialization and deserialization" + " works with one batch") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> From 77ffe9531c082ee6a47d2a62f86779ff4d957690 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 3 Jul 2024 15:02:30 -0700 Subject: [PATCH 18/22] hdfsmetadatalog --- .../execution/streaming/HDFSMetadataLog.scala | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 251cc16acdf43..97589b83ec3cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization @@ -47,10 +48,25 @@ import org.apache.spark.util.ArrayImplicits._ * * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing * files in a directory always shows the latest files. + * @param hadoopConf Hadoop configuration that is used to read / write metadata files. + * @param path Path to the directory that will be used for writing metadata. + * @param metadataCacheEnabled Whether to cache the batches' metadata in memory. */ -class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: String) +class HDFSMetadataLog[T <: AnyRef : ClassTag]( + hadoopConf: Configuration, + path: String, + val metadataCacheEnabled: Boolean = false) extends MetadataLog[T] with Logging { + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + private implicit val formats: Formats = Serialization.formats(NoTypeHints) /** Needed to serialize type T into JSON when using Jackson */ @@ -64,15 +80,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val metadataPath = new Path(path) protected val fileManager = - CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf()) + CheckpointFileManager.create(metadataPath, hadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) } - protected val metadataCacheEnabled: Boolean - = sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_METADATA_CACHE_ENABLED) - /** * Cache the latest two batches. [[StreamExecution]] usually just accesses the latest two batches * when committing offsets, this cache will save some file system operations. 
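With HDFSMetadataLog now accepting a Hadoop Configuration directly (the SparkSession constructor is kept as an auxiliary constructor), logs built on it, such as OperatorStateMetadataLog, can be opened without a SparkSession, as the new OperatorStateMetadataV2 round-trip test above already does. A hedged usage sketch; package names are taken from the file paths in this series, and the exact visibility of these internal classes is an assumption:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2

// Resolves the operator's metadata file path and opens the log over it,
// using only a Hadoop Configuration (no SparkSession in scope).
def openOperatorMetadataLog(
    hadoopConf: Configuration,
    checkpointLocation: String,
    operatorId: Long): OperatorStateMetadataLog = {
  val operatorPath = new Path(checkpointLocation, operatorId.toString)
  val metadataFilePath = OperatorStateMetadataV2.metadataFilePath(operatorPath)
  new OperatorStateMetadataLog(hadoopConf, metadataFilePath.toString)
}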
From 6c90c9f035fd4cf65653c8348f52ae22216c83a2 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 3 Jul 2024 18:44:14 -0700 Subject: [PATCH 19/22] feedback --- .../spark/sql/execution/streaming/IncrementalExecution.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a0ace912faf7d..54ae908869718 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -85,6 +85,7 @@ class IncrementalExecution( .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2 /** * This value dictates which schema format version the state schema should be written in * for all operators other than TransformWithState. From b63859276cfaa48b81f2dbc1ebbfe712352a6b9c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 8 Jul 2024 11:32:14 -0700 Subject: [PATCH 20/22] checking the OperatorStateMetadata log for the state schema file --- .../streaming/IncrementalExecution.scala | 1 + .../streaming/OperatorStateMetadataLog.scala | 2 ++ .../streaming/TransformWithStateExec.scala | 22 +++++++++++++++++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 54ae908869718..dc10bd5f43521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -214,6 +214,7 @@ class IncrementalExecution( val metadata = stateStoreWriter.operatorStateMetadata() stateStoreWriter match { case tws: TransformWithStateExec => + logError(s"### checkpointLocation: $checkpointLocation") val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path( checkpointLocation, tws.getStateInfo.operatorId.toString)) val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala index 80e334be1101f..5b9846d30aa37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -51,6 +51,8 @@ class OperatorStateMetadataLog( case 1 => OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata) case 2 => + logError(s"### stateSchemaPath: ${metadata.asInstanceOf[OperatorStateMetadataV2]. 
+ stateStoreInfo.head.stateSchemaFilePath}") OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index f03261b18e3e5..42dea5d7c2452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -380,19 +380,37 @@ case class TransformWithStateExec( ) } + private def fetchOperatorStateMetadataLog( + hadoopConf: Configuration, + checkpointDir: String, + operatorId: Long): OperatorStateMetadataLog = { + val checkpointPath = new Path(checkpointDir, operatorId.toString) + val operatorStateMetadataPath = OperatorStateMetadataV2.metadataFilePath(checkpointPath) + new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString) + } + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): Array[String] = { assert(stateSchemaVersion >= 3) - val newColumnFamilySchemas = getColFamilySchemas() + val newSchemas = getColFamilySchemas() val schemaFile = new StateSchemaV3File( hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString) // TODO: Read the schema path from the OperatorStateMetadata file // and validate it with the new schema + val operatorStateMetadataLog = fetchOperatorStateMetadataLog( + hadoopConf, getStateInfo.checkpointLocation, getStateInfo.operatorId) + val mostRecentLog = operatorStateMetadataLog.getLatest() + val oldSchemas = mostRecentLog.map(_._2.asInstanceOf[OperatorStateMetadataV2]) + .map(_.stateStoreInfo.map(_.stateSchemaFilePath)).getOrElse(Array.empty) + .flatMap { schemaPath => + schemaFile.getWithPath(new Path(schemaPath)) + }.toList + validateSchemas(oldSchemas, newSchemas) // Write the new schema to the schema file - val schemaPath = schemaFile.addWithUUID(batchId, newColumnFamilySchemas.values.toList) + val schemaPath = schemaFile.addWithUUID(batchId, newSchemas.values.toList) Array(schemaPath.toString) } From 58e194788a8c4a91ce2059ce33c9dabfcfd4672f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 8 Jul 2024 11:48:38 -0700 Subject: [PATCH 21/22] adding todo --- .../streaming/IncrementalExecution.scala | 1 - .../streaming/OperatorStateMetadataLog.scala | 2 - .../streaming/TransformWithStateSuite.scala | 62 +++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index dc10bd5f43521..54ae908869718 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -214,7 +214,6 @@ class IncrementalExecution( val metadata = stateStoreWriter.operatorStateMetadata() stateStoreWriter match { case tws: TransformWithStateExec => - logError(s"### checkpointLocation: $checkpointLocation") val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path( checkpointLocation, tws.getStateInfo.operatorId.toString)) val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession, diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala index 5b9846d30aa37..80e334be1101f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -51,8 +51,6 @@ class OperatorStateMetadataLog( case 1 => OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata) case 2 => - logError(s"### stateSchemaPath: ${metadata.asInstanceOf[OperatorStateMetadataV2]. - stateStoreInfo.head.stateSchemaFilePath}") OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 3ef407667e5f3..22ffc5d8367e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -64,6 +64,32 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S } } +class RunningCountStatefulProcessorInt extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient protected var _countState: ValueState[Int] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + // Class to verify stateful processor usage with adding processing time timers class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor { private def handleProcessingTimeBasedTimers( @@ -1054,6 +1080,42 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + // TODO: Enable this test and expect error to be thrown when + // github.com/apache/spark/pull/47257 is merged + ignore("test that invalid schema evolution fails query for column family") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest { 
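Background for the ignored test above: across the two runs the same countState column family is registered first with Encoders.scalaLong (RunningCountStatefulProcessor) and then with Encoders.scalaInt (RunningCountStatefulProcessorInt), and the corresponding Catalyst schemas are not equal. That mismatch is exactly what the schema validation introduced earlier in this series is meant to reject once the referenced PR lands. A minimal shell-style illustration:

import org.apache.spark.sql.Encoders

// Value schema of the long-typed processor vs. the int-typed one.
val longValueSchema = Encoders.scalaLong.schema // struct<value:bigint>
val intValueSchema = Encoders.scalaInt.schema   // struct<value:int>
assert(longValueSchema != intValueSchema)       // different value types => incompatible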
From 6ff37f41615d275280f4736efe9e30acd207db94 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 9 Jul 2024 10:43:50 -0700 Subject: [PATCH 22/22] removing println, test passes --- .../streaming/IncrementalExecution.scala | 3 +- .../streaming/state/StateSchemaV3File.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 60 +++++++++++-------- .../TransformWithValueStateTTLSuite.scala | 1 - 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 54ae908869718..3fe5aeae5f637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -85,7 +85,6 @@ class IncrementalExecution( .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) - private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2 /** * This value dictates which schema format version the state schema should be written in * for all operators other than TransformWithState. @@ -211,7 +210,7 @@ class IncrementalExecution( // write out the state schema paths to the metadata file statefulOp match { case stateStoreWriter: StateStoreWriter => - val metadata = stateStoreWriter.operatorStateMetadata() + val metadata = stateStoreWriter.operatorStateMetadata(stateSchemaPaths) stateStoreWriter match { case tws: TransformWithStateExec => val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala index 07a94400f30f4..a73bf6ff5a320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala @@ -50,7 +50,7 @@ class StateSchemaV3File( fileManager.mkdirs(metadataPath) } - private def deserialize(in: InputStream): List[ColumnFamilySchema] = { + private[sql] def deserialize(in: InputStream): List[ColumnFamilySchema] = { val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { @@ -63,7 +63,7 @@ class StateSchemaV3File( lines.map(ColumnFamilySchemaV1.fromJson).toList } - private def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { + private[sql] def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { out.write(s"v${StateSchemaV3File.VERSION}".getBytes(UTF_8)) out.write('\n') out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 22ffc5d8367e9..7954ee510ca1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ -import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, TestClass} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -881,11 +881,12 @@ class TransformWithStateSuite extends StateStoreMetricsTest val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) assert(columnFamilySchemas.length == 1) - val expected = ColumnFamilySchemaV1( "countState", - KEY_ROW_SCHEMA, - VALUE_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None ) @@ -922,18 +923,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest val expected = List( ColumnFamilySchemaV1( "countState", - KEY_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), - false, - Encoders.scalaLong.schema, None ), ColumnFamilySchemaV1( "mostRecent", - KEY_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", StringType, true)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), - false, - Encoders.STRING.schema, None ) ) @@ -1017,8 +1020,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest withTempDir { checkpointDir => val schema = List(ColumnFamilySchemaV1( "countState", - KEY_ROW_SCHEMA, - VALUE_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None )) @@ -1042,8 +1047,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest val schema0 = List(ColumnFamilySchemaV1( "countState", - KEY_ROW_SCHEMA, - VALUE_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None )) @@ -1051,15 +1058,19 @@ class TransformWithStateSuite extends StateStoreMetricsTest val schema1 = List( ColumnFamilySchemaV1( "countState", - KEY_ROW_SCHEMA, - VALUE_ROW_SCHEMA, + new 
StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None ), ColumnFamilySchemaV1( "mostRecent", - KEY_ROW_SCHEMA, - VALUE_ROW_SCHEMA, + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", StringType, false)), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), None ) @@ -1081,9 +1092,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - // TODO: Enable this test and expect error to be thrown when - // github.com/apache/spark/pull/47257 is merged - ignore("test that invalid schema evolution fails query for column family") { + test("test that invalid schema evolution fails query for column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1110,8 +1119,11 @@ class TransformWithStateSuite extends StateStoreMetricsTest testStream(result2, OutputMode.Update())( StartStream(checkpointLocation = checkpointDir.getCanonicalPath), AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), - StopStream + ExpectFailure[StateStoreValueSchemaNotCompatible] { + (t: Throwable) => { + assert(t.getMessage.contains("Please check number and type of fields.")) + } + } ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 3f7a0be5759c3..d4d730dac9996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -319,7 +319,6 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { AdvanceManualClock(1 * 1000), CheckNewAnswer(), Execute { q => - println("last progress:" + q.lastProgress) val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString, metadataPathPostfix).toString)
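The reworked expectations in this final patch spell the key and value row schemas out as nested structs instead of the flat KEY_ROW_SCHEMA / VALUE_ROW_SCHEMA constants. For reference, this is the shape they build for countState (taken directly from the diff; whether the outer field names stay "key" and "value" everywhere else is an assumption based on this series):

import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Key row schema: a single "key" field wrapping the grouping-key encoder's schema.
val countStateKeySchema = new StructType()
  .add("key", new StructType().add("value", StringType))

// Value row schema: a single "value" field wrapping the state value encoder's
// schema (non-nullable long for Encoders.scalaLong).
val countStateValueSchema = new StructType()
  .add("value", new StructType().add("value", LongType, false))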