From 884ae6cf3a819358bad6976787819a526d4513b4 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Sun, 17 Mar 2024 21:28:07 -0700 Subject: [PATCH 01/21] Add support for ValueState TTL. --- .../main/resources/error/error-classes.json | 5 + ...ditions-unsupported-feature-error-class.md | 4 + .../apache/spark/sql/streaming/TTLMode.java | 51 +++++++ .../sql/catalyst/plans/logical/TTLMode.scala | 26 ++++ .../streaming/StatefulProcessorHandle.scala | 11 +- .../spark/sql/streaming/ValueState.scala | 4 +- .../sql/catalyst/plans/logical/object.scala | 7 +- .../spark/sql/KeyValueGroupedDataset.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../execution/streaming/MapStateImpl.scala | 2 +- .../streaming/StateTypesEncoderUtils.scala | 42 +++++- .../streaming/StateVariableTTLSupport.scala | 116 +++++++++++++++ .../StatefulProcessorHandleImpl.scala | 31 +++- .../execution/streaming/TimerStateImpl.scala | 8 +- .../streaming/TransformWithStateExec.scala | 42 +++--- .../execution/streaming/ValueStateImpl.scala | 94 ++++++++++--- .../execution/streaming/state/RocksDB.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 7 + .../streaming/state/StateStore.scala | 5 + .../streaming/state/StateStoreErrors.scala | 14 +- .../streaming/state/ListStateSuite.scala | 16 ++- .../streaming/state/MapStateSuite.scala | 11 +- .../state/StatefulProcessorHandleSuite.scala | 15 +- .../streaming/state/ValueStateSuite.scala | 27 ++-- .../TransformWithListStateSuite.scala | 8 ++ .../TransformWithMapStateSuite.scala | 5 + .../streaming/TransformWithStateSuite.scala | 133 ++++++++++++++++++ 27 files changed, 611 insertions(+), 92 deletions(-) create mode 100644 sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 11c8204d2c93..bf9628da733e 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -4353,6 +4353,11 @@ "Removing column families with is not supported." ] }, + "STATE_STORE_TTL" : { + "message" : [ + "State TTL with is not supported. Please use RocksDBStateStoreProvider." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table does not support . Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." diff --git a/docs/sql-error-conditions-unsupported-feature-error-class.md b/docs/sql-error-conditions-unsupported-feature-error-class.md index e580ecc63b18..f67d7caff63d 100644 --- a/docs/sql-error-conditions-unsupported-feature-error-class.md +++ b/docs/sql-error-conditions-unsupported-feature-error-class.md @@ -202,6 +202,10 @@ Creating multiple column families with `` is not supported. Removing column families with `` is not supported. +## STATE_STORE_TTL + +State TTL with `` is not supported. Please use RocksDBStateStoreProvider. + ## TABLE_OPERATION Table `` does not support ``. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". 
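An illustrative usage sketch (not part of this patch): the snippet below shows how a StatefulProcessor might use the TTL-enabled value state introduced here, mirroring the test processors further down in the patch. The processor name, state name, and one-hour duration are hypothetical; the signatures (getValueState, update with a ttlDuration, and the transformWithState overload taking a TTLMode) follow the API as modified in this patch set and may change in later patches of the series.

    import java.time.Duration
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming._

    class CountWithTTLProcessor extends StatefulProcessor[String, String, (String, Long)] {
      @transient private var countState: ValueState[Long] = _

      override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
        // Value state that participates in TTL cleanup when a TTL mode is set on the query.
        countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
      }

      override def handleInputRows(
          key: String,
          inputRows: Iterator[String],
          timerValues: TimerValues,
          expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
        val count = countState.getOption().getOrElse(0L) + inputRows.size
        // With ProcessingTimeTTL, the value expires one hour after the current batch timestamp.
        countState.update(count, Duration.ofHours(1))
        Iterator.single((key, count))
      }

      override def close(): Unit = {}
    }

    // Wiring it into a query (transformWithState is still private[sql] in this patch):
    // inputDataset
    //   .groupByKey(row => row)
    //   .transformWithState(new CountWithTTLProcessor(), TimeoutMode.NoTimeouts(),
    //     TTLMode.ProcessingTimeTTL(), OutputMode.Update())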
diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java new file mode 100644 index 000000000000..210f4b78eb84 --- /dev/null +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.catalyst.plans.logical.*; + +/** + * Represents the type of ttl modes possible for user defined state + * in [[StatefulProcessor]]. + */ +@Experimental +@Evolving +public class TTLMode { + + /** + * Specifies that there is no TTL for the state object. Such objects would not + * be cleaned up by Spark automatically. + */ + public static final TTLMode NoTTL() { + return NoTTL$.MODULE$; + } + + /** + * Specifies that the specified ttl is in processing time. + */ + public static final TTLMode ProcessingTimeTTL() { + return ProcessingTimeTTL$.MODULE$; + } + + /** + * Specifies that the specified ttl is in event time. + */ + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala new file mode 100644 index 000000000000..e0e02868fbf4 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.streaming.TTLMode + +/** Types of ttl modes used in transformWithState operator */ +case object NoTTL extends TTLMode + +case object ProcessingTimeTTL extends TTLMode + +case object EventTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 560188a0ff62..10a914e11247 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -30,15 +30,20 @@ import org.apache.spark.sql.Encoder private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type + * Function to create new or return existing single value state variable of given type. + * The state will eventually be cleaned up after the specified ttl. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. - * @param stateName - name of the state variable + * + * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable + * @param ttlMode - ttl mode for the state * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] + def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] /** * Creates new or returns existing list state associated with stateName. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 9c707c8308ab..36ee12fa83e1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable +import java.time.Duration import org.apache.spark.annotation.{Evolving, Experimental} @@ -43,7 +44,8 @@ private[sql] trait ValueState[S] extends Serializable { def getOption(): Option[S] /** Update the value of the state. */ - def update(newState: S): Unit + // TODO(sahnib) confirm if this should be scala or Java type of Duration + def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Remove this state.
*/ def clear(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index b2c443a8cce0..ff7c8fb3df4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.sql.types._ object CatalystSerde { @@ -574,6 +574,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { @@ -584,6 +585,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -605,6 +607,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan, @@ -618,6 +621,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -639,6 +643,7 @@ case class TransformWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 95ad973aee51..f9758e4cbf98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. 
Users should not @@ -662,6 +662,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { Dataset[U]( sparkSession, @@ -669,6 +670,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan @@ -693,6 +695,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { @@ -702,6 +705,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cc212d99f299..e124c6f2edc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -751,7 +751,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => @@ -761,6 +761,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -925,15 +926,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, planLater(child), hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, - initialStateDeserializer, planLater(initialState)) :: Nil + initialStateDeserializer, planLater (initialState)) :: Nil case _: FlatMapGroupsInPandasWithState => // TODO(SPARK-40443): support applyInPandasWithState in batch query diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index d2ccd0a77807..c58f32ed756d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -45,7 +45,7 @@ class MapStateImpl[K, V]( /** Whether state exists or not. */ override def exists(): Boolean = { - !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty + store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1d41db896cdf..f5509abfc52f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -23,11 +23,19 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructType} object StateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) - val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType) + val VALUE_ROW_SCHEMA: StructType = new StructType() + .add("value", BinaryType) + .add("ttlExpirationMs", LongType) + + def encodeGroupingKeyBytes(keyBytes: Array[Byte]): UnsafeRow = { + val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) + val keyRow = keyProjection(InternalRow(keyBytes)) + keyRow + } } /** @@ -65,30 +73,50 @@ class StateTypesEncoder[GK, V]( // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. 
def encodeGroupingKey(): UnsafeRow = { + val keyRow = keyProjection(InternalRow(serializeGroupingKey())) + keyRow + } + + def encodeSerializedGroupingKey( + groupingKeyBytes: Array[Byte]): UnsafeRow = { + val keyRow = keyProjection(InternalRow(groupingKeyBytes)) + keyRow + } + + def serializeGroupingKey(): Array[Byte] = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(stateName) } - val groupingKey = keyOption.get.asInstanceOf[GK] - val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() - val keyRow = keyProjection(InternalRow(keyByteArr)) - keyRow + keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes)) + val valRow = valueProjection(InternalRow(bytes, 0L)) valRow } + def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + val valueRow = valueProjection(InternalRow(bytes, expirationMs)) + valueRow + } + def decodeValue(row: UnsafeRow): V = { val bytes = row.getBinary(0) reusedValRow.pointTo(bytes, bytes.length) val value = rowToObjDeserializer.apply(reusedValRow) value } + + def decodeTtlExpirationMs(row: UnsafeRow): Long = { + val expirationMs = row.getLong(1) + expirationMs + } } object StateTypesEncoder { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala new file mode 100644 index 000000000000..836fff647a2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +trait StateVariableTTLSupport { + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +trait TTLState { + def clearExpiredState(): Unit +} + +class SingleKeyTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long], + state: StateVariableTTLSupport) + extends TTLState { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + override def clearExpiredState(): Unit = { + store.iterator(ttlColumnFamilyName).foreach { kv => + val expirationMs = kv.key.getLong(0) + val isExpired = StateTTL.isExpired(ttlMode, expirationMs, + batchTimestampMs, eventTimeWatermarkMs) + + if (isExpired) { + val groupingKey = kv.key.getBinary(1) + state.clearIfExpired(groupingKey) + + // TODO(sahnib) ; validate its safe to update inside iterator + store.remove(kv.key, ttlColumnFamilyName) + } + } + } +} + +object StateTTL { + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + -1L + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get > expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get > expirationMs + } else { + false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 5f3b794fd117..367203fc7f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import 
java.util import java.util.UUID import org.apache.spark.TaskContext @@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, TTLMode, ValueState} import org.apache.spark.util.Utils /** @@ -77,14 +78,19 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, - isStreaming: Boolean = true) + isStreaming: Boolean = true, + batchTimestampMs: Option[Long] = None, + eventTimeWatermarkMs: Option[Long] = None) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - private def buildQueryInfo(): QueryInfo = { + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { (BATCH_QUERY_ID, 0L) @@ -115,9 +121,16 @@ class StatefulProcessorHandleImpl( def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verifyStateVarOperations("get_value_state") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] = { + verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + + "initialization is complete") + + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, + ttlMode, batchTimestampMs, eventTimeWatermarkMs) + resultState.ttlState.foreach(ttlStates.add(_)) + resultState } @@ -185,6 +198,12 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + def doTtlCleanup(): Unit = { + ttlStates.forEach { s => + s.clearExpiredState() + } + } + /** * Function to delete and purge state variable if defined previously * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index af321eecb4db..8d410b677c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -78,25 +78,25 @@ class TimerStateImpl( private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) - val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { TimerStateUtils.PROC_TIMERS_STATE_NAME } else { TimerStateUtils.EVENT_TIMERS_STATE_NAME } - val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), useMultipleValuesPerKey = false, isInternal = true) - val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + private val 
tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1), useMultipleValuesPerKey = false, isInternal = true) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption - if (!keyOption.isDefined) { + if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(cfName) } keyOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 36b957f9d430..8371fdec5e8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.streaming._ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -42,6 +42,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data + * @param ttlMode defines the ttl Mode * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -58,6 +59,7 @@ case class TransformWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -102,10 +104,6 @@ case class TransformWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes - protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - /** * Distribute by grouping attributes - We need the underlying data and the initial state data * to have the same grouping so that the data are co-located on the same task. 
@@ -283,6 +281,8 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { + // join ttlBackgroundThread forkjoinpool + processorHandle.doTtlCleanup() store.commit() } else { store.abort() @@ -331,9 +331,9 @@ case class TransformWithStateExec( val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) val store = StateStore.get( storeProviderId = storeProviderId, - keySchema = schemaForKeyRow, - valueSchema = schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), version = stateInfo.get.storeVersion, useColumnFamilies = true, storeConf = storeConf, @@ -351,9 +351,9 @@ case class TransformWithStateExec( if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator), useColumnFamilies = true @@ -401,9 +401,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider = StateStoreProvider.createAndInit( providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useColumnFamilies = true, storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value, @@ -426,7 +426,8 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode, + isStreaming, batchTimestampMs, eventTimeWatermarkForEviction) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeoutMode) @@ -440,7 +441,7 @@ case class TransformWithStateExec( initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeoutMode, isStreaming) + keyEncoder, ttlMode, timeoutMode, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeoutMode) @@ -463,7 +464,7 @@ case class TransformWithStateExec( } } -// scalastyle:off +// scalastyle:off argcount object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -473,6 +474,7 @@ object TransformWithStateExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -498,6 +500,7 @@ object TransformWithStateExec { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -515,4 +518,5 @@ object TransformWithStateExec { initialState) } } -// scalastyle:on +// scalastyle:on argcount + diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 08876ca3032e..b09dd43c30b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -16,39 +16,59 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} -import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** * Class that provides a concrete implementation for a single value state associated with state * variables used in the streaming transformWithState operator. * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition - * @param keyEnc - Spark SQL encoder for key + * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param ttlMode ttl Mode to evict expired values from state + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermarkMs event time watermark for state eviction * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) extends ValueState[S] with Logging { + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ValueState[S] + with Logging + with StateVariableTTLSupport { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLState] = None + + initialize() - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + + if (ttlMode != TTLMode.NoTTL()) { + val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs, this) + ttlState = Some(_ttlState) + } + } /** Function to check if state exists. 
Returns true if present and false otherwise */ override def exists(): Boolean = { - getImpl() != null + get() != null } /** Function to return Option of value if exists and None otherwise */ @@ -58,26 +78,66 @@ class ValueStateImpl[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val retRow = getImpl() + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + if (retRow != null) { - stateTypesEncoder.decodeValue(retRow) + val resState = stateTypesEncoder.decodeValue(retRow) + + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) + val isExpired = StateTTL.isExpired(ttlMode, + expirationMs, batchTimestampMs, eventTimeWatermarkMs) + + if (!isExpired) { + resState + } else { + null.asInstanceOf[S] + } } else { null.asInstanceOf[S] } } - private def getImpl(): UnsafeRow = { - store.get(stateTypesEncoder.encodeGroupingKey(), stateName) - } - /** Function to update and overwrite state associated with given key */ - override def update(newState: S): Unit = { - store.put(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { + // TODO(sahnib) throw a StateStoreError here + throw new RuntimeException() + } + + var expirationMs: Long = -1 + if (ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + stateTypesEncoder.encodeValue(newState, expirationMs), stateName) + + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, serializedGroupingKey)) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) + val isExpired = StateTTL.isExpired(ttlMode, + expirationMs, batchTimestampMs, eventTimeWatermarkMs) + + // Remove the value only when it has actually expired. + if (isExpired) { + store.remove(encodedGroupingKey, stateName) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1817104a5c22..1c1159f65454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,10 @@ class RocksDB( } } + def listColumnFamilies(): Seq[String] = { + Seq() + } + /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c8537f2a6a5b..c6db9b5d000c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,6 +257,13 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } + + /** Return a list of column family names */ + override def listColumnFamilies(): Seq[String] = { + verify(useColumnFamilies, "Column families are not supported in this store") + // turn keys of keyValueEncoderMap to Seq + rocksDB.listColumnFamilies() + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dd97aa5b9afc..c79c244dcf44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,6 +122,11 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit + /** + * Return list of column family names. + */ + def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) + /** * Create column family with given name, if absent. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 2f72cbb0b0fc..e86670355cae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -49,6 +49,11 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } + def ttlNotSupportedWithProvider(stateStoreProvider: String): + StateStoreTTLNotSupportedException = { + new StateStoreTTLNotSupportedException(stateStoreProvider) + } + def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -122,7 +127,14 @@ object StateStoreErrors { class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) + +class StateStoreTTLNotSupportedException(stateStoreProvider: String) + extends SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index e895e475b74d..51cfc1548b39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, ValueState} /** * Class that adds unit tests for ListState types used in arbitrary stateful @@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase { } checkError( - exception = e.asInstanceOf[SparkIllegalArgumentException], + exception = e, errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") @@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index ce72061d39ea..7fa41b12795e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import 
org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 662a5dbfaac4..aec828459fce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode} + /** * Class that adds tests to verify operations based on stateful processor handle @@ -48,7 +49,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) 
assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -89,7 +90,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -107,7 +108,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } @@ -143,7 +144,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -164,7 +165,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -204,7 +205,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8668b58672c7..8063c2cdb155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode, ValueState} import 
org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -48,7 +48,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -78,7 +79,7 @@ class ValueStateSuite extends StateVariableSuiteBase { testState.update(123) } checkError( - ex.asInstanceOf[SparkException], + ex1.asInstanceOf[SparkException], errorClass = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" @@ -92,7 +93,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -118,7 +120,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,7 +167,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], - TimeoutMode.NoTimeouts()) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val cfName = "_testState" val ex = intercept[SparkUnsupportedOperationException] { @@ -204,7 +207,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +234,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 
+261,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +288,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 95ab34d40131..51cc9ff87890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -140,6 +140,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -160,6 +161,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -180,6 +182,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -200,6 +203,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -220,6 +224,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -240,6 +245,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -260,6 +266,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -312,6 +319,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x) .transformWithState(new ToggleSaveAndEmitProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, 
OutputMode.Update())( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index db8cb8b810af..3c1280137af9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -95,6 +95,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) @@ -121,6 +122,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -145,6 +147,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -168,6 +171,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) testStream(result, OutputMode.Append())( // Test exists() @@ -222,6 +226,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24e68e3db9d8..8893fa118cf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,6 +34,35 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorZeroTTL + extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState("countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } + + override def close(): Unit = {} +} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -231,6 +260,7 @@ class RunningCountMostRecentStatefulProcessor _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } + override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -249,6 +279,41 @@ class RunningCountMostRecentStatefulProcessor } } +class 
RunningCountMostRecentStatefulProcessorWithTTL + extends StatefulProcessor[String, (String, String), (String, String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + @transient private var _mostRecent: ValueState[String] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState( + "countState", Encoders.scalaLong) + _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + val mostRecent = _mostRecent.getOption().getOrElse("") + + var output = List[(String, String, String)]() + inputRows.foreach { row => + _mostRecent.update(row._2) + _countState.update(count) + output = (key, count.toString, mostRecent) :: output + } + output.iterator + } + + override def close(): Unit = {} +} + class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -310,6 +375,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithError(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -331,6 +397,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -361,6 +428,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -404,6 +472,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -440,6 +509,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithMultipleTimers(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -475,6 +545,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new MaxEventTimeStatefulProcessor(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -516,6 +587,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() @@ -534,12 +606,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x._1) .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) val stream2 = inputData.toDS() .groupByKey(x => x._1) .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(stream1, 
OutputMode.Update())( @@ -559,6 +633,61 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - Zero duration TTL, should expire immediately") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorZeroTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + StopStream + ) + } + } + + test("transformWithState - multiple state variables with one TTL") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[(String, String)] + val stream1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(stream1, OutputMode.Update())( + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str2"), ("b", "str3")), + CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), + AddData(inputData, ("a", "str4"), ("b", "str5")), + CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), + StopStream + ) + } + } + test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -572,6 +701,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -605,6 +735,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -638,6 +769,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -760,6 +892,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( From 2c311e209dd127c919c76e2254334f60bc5d6e78 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 13:09:12 -0700 Subject: [PATCH 02/21] Fix existing testcases, and add TTL testcase. 
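A minimal sketch of the TTL behaviour exercised by the new testcase (names such as _countState mirror the ValueStateTTLProcessor added below; this is an illustration only, not part of the diff):

    // Inside StatefulProcessor.handleInputRows, with the query started in
    // TTLMode.ProcessingTimeTTL(): write the value together with a 1 minute TTL.
    import java.time.Duration
    val total = _countState.getOption().getOrElse(0L) + inputRows.size
    _countState.update(total, Duration.ofMinutes(1))
    // Once the batch timestamp moves past the expiration (the suite advances a
    // StreamManualClock by 60s), get() no longer returns the value and the row is
    // removed by doTtlCleanup() at the end of the micro-batch.
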
--- .../streaming/StateVariableTTLSupport.scala | 13 +- .../StatefulProcessorHandleImpl.scala | 11 +- .../streaming/TransformWithStateExec.scala | 1 + .../execution/streaming/ValueStateImpl.scala | 2 +- .../streaming/TransformWithStateSuite.scala | 119 ------------------ .../TransformWithStateTTLSuite.scala | 89 +++++++++++++ 6 files changed, 109 insertions(+), 126 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala index 836fff647a2b..07fa8c141b15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.time.Duration +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} @@ -45,14 +46,15 @@ class SingleKeyTTLState( stateName: String, store: StateStore, batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long], - state: StateVariableTTLSupport) - extends TTLState { + eventTimeWatermarkMs: Option[Long]) + extends TTLState + with Logging { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ private val ttlColumnFamilyName = s"_ttl_$stateName" private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + private var state: StateVariableTTLSupport = _ // empty row used for values private val EMPTY_ROW = @@ -83,6 +85,11 @@ class SingleKeyTTLState( } } } + + private[sql] def setStateVariable( + state: StateVariableTTLSupport): Unit = { + this.state = state + } } object StateTTL { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 367203fc7f21..b74d09232421 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -89,6 +89,7 @@ class StatefulProcessorHandleImpl( private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" + logInfo(s"Created StatefulProcessorHandle") private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) @@ -124,12 +125,16 @@ class StatefulProcessorHandleImpl( override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("get_value_state") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) - resultState.ttlState.foreach(ttlStates.add(_)) + val ttlState = resultState.ttlState + + ttlState.foreach { s => + ttlStates.add(s) + s.setStateVariable(resultState) + } resultState } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 8371fdec5e8c..30de84d550bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -299,6 +299,7 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // TODO(sahnib) add validation for ttlMode timeoutMode match { case ProcessingTime => if (batchTimestampMs.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index b09dd43c30b8..7cfc873ef3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -61,7 +61,7 @@ class ValueStateImpl[S]( if (ttlMode != TTLMode.NoTTL()) { val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, - batchTimestampMs, eventTimeWatermarkMs, this) + batchTimestampMs, eventTimeWatermarkMs) ttlState = Some(_ttlState) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 8893fa118cf3..539211393214 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,35 +34,6 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } -class RunningCountStatefulProcessorZeroTTL - extends StatefulProcessor[String, String, (String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - - override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState("countState", Encoders.scalaLong) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - if (count == 3) { - _countState.clear() - Iterator.empty - } else { - _countState.update(count) - Iterator((key, count.toString)) - } - } - - override def close(): Unit = {} -} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -279,41 +250,6 @@ class RunningCountMostRecentStatefulProcessor } } -class RunningCountMostRecentStatefulProcessorWithTTL - extends StatefulProcessor[String, (String, String), (String, String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - @transient private var _mostRecent: ValueState[String] = _ - - override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState( - "countState", Encoders.scalaLong) - _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) - } - - override def handleInputRows( - key: String, - inputRows: 
Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - val mostRecent = _mostRecent.getOption().getOrElse("") - - var output = List[(String, String, String)]() - inputRows.foreach { row => - _mostRecent.update(row._2) - _countState.update(count) - output = (key, count.toString, mostRecent) :: output - } - output.iterator - } - - override def close(): Unit = {} -} - class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -633,61 +569,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - Zero duration TTL, should expire immediately") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[String] - - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorZeroTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - StopStream - ) - } - } - - test("transformWithState - multiple state variables with one TTL") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[(String, String)] - val stream1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(stream1, OutputMode.Update())( - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str2"), ("b", "str3")), - CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), - AddData(inputData, ("a", "str4"), ("b", "str5")), - CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), - StopStream - ) - } - } - test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala new file mode 100644 index 000000000000..a854781c5e2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock + +class ValueStateTTLProcessor + extends StatefulProcessor[String, String, (String, Long)] + with Logging { + + @transient private var _countState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _countState = getHandle.getValueState("countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = { + + val currValueOption = _countState.getOption() + + var totalLogins: Long = inputRows.size + if (currValueOption.isDefined) { + totalLogins = totalLogins + currValueOption.get + } + + _countState.update(totalLogins, Duration.ofMinutes(1)) + + Iterator.single((key, totalLogins)) + } +} + +class TransformWithStateTTLSuite + extends StreamTest { + import testImplicits._ + + test("validate state is evicted at ttl expiry") { + + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[String] + val result = inputStream.toDS() + .groupByKey(x => x) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, "k1"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("k1", 1L)), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + AddData(inputStream, "k1"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("k1", 1)) + ) + } + } +} From a79183801d428ede6e7515bc53c8e36517cd537c Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 16:44:07 -0700 Subject: [PATCH 03/21] Improved documentation, added NERF error classes for user errors. 
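A sketch of one of the new user-facing errors, under the assumption that the query was started with TTLMode.NoTTL() (illustration only, not part of the diff):

    // Inside StatefulProcessor.handleInputRows, when the operator runs in TTLMode.NoTTL():
    _countState.update(count, Duration.ofMinutes(1))
    // throws STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE, because a non-zero
    // ttlDuration is only valid in ProcessingTimeTTL or EventTimeTTL mode; a plain
    // _countState.update(count) remains valid in every mode.
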
--- .../main/resources/error/error-classes.json | 6 ++ .../apache/spark/sql/streaming/TTLMode.java | 16 ++--- .../sql/catalyst/plans/logical/TTLMode.scala | 2 +- .../streaming/StatefulProcessorHandle.scala | 6 +- .../streaming/StateTypesEncoderUtils.scala | 32 +++++++++- ...cala => StateVariableWithTTLSupport.scala} | 49 ++++++++++++-- .../StatefulProcessorHandleImpl.scala | 10 ++- .../streaming/TransformWithStateExec.scala | 52 ++++++++++----- .../execution/streaming/ValueStateImpl.scala | 33 +++++----- .../execution/streaming/state/RocksDB.scala | 4 -- .../state/RocksDBStateStoreProvider.scala | 7 -- .../streaming/state/StateStore.scala | 5 -- .../streaming/state/StateStoreErrors.scala | 33 ++++++---- .../TransformWithStateTTLSuite.scala | 64 +++++++++++++------ 14 files changed, 210 insertions(+), 109 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{StateVariableTTLSupport.scala => StateVariableWithTTLSupport.scala} (69%) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index bf9628da733e..9d2524751317 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3541,6 +3541,12 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : { + "message" : [ + "State store operation= on state= does not support TTL in NoTTL() mode." + ], + "sqlState" : "42802" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation= with invalid handle state=." diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index 210f4b78eb84..016407db5e10 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -22,30 +22,30 @@ import org.apache.spark.sql.catalyst.plans.logical.*; /** - * Represents the type of ttl modes possible for user defined state - * in [[StatefulProcessor]]. + * Represents the type of ttl modes possible for the Dataset operations + * {@code transformWithState}. */ @Experimental @Evolving public class TTLMode { /** - * Specifies that there is no TTL for the state object. Such objects would not + * Specifies that there is no TTL for the user state. User state would not * be cleaned up by Spark automatically. */ - public static final TTLMode NoTTL() { + public static TTLMode NoTTL() { return NoTTL$.MODULE$; } /** - * Specifies that the specified ttl is in processing time. + * Specifies that all ttl durations for user state are in processing time. */ - public static final TTLMode ProcessingTimeTTL() { + public static TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** - * Specifies that the specified ttl is in event time. + * Specifies that all ttl durations for user state are in event time. 
*/ - public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } + public static TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala index e0e02868fbf4..4b43060e0d35 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.streaming.TTLMode -/** Types of timeouts used in tranformWithState operator */ +/** TTL types used in tranformWithState operator */ case object NoTTL extends TTLMode case object ProcessingTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 10a914e11247..30f2d9000ecc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -31,19 +31,15 @@ private[sql] trait StatefulProcessorHandle extends Serializable { /** * Function to create new or return existing single value state variable of given type. - * The state will be eventually cleaned up after the specified ttl. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. * * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable - * @param ttlMode - ttl mode for the state * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[T]( - stateName: String, - valEncoder: Encoder[T]): ValueState[T] + def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] /** * Creates new or returns existing list state associated with stateName. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index f5509abfc52f..f97b91ccc960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -70,6 +70,8 @@ class StateTypesEncoder[GK, V]( private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer() private val reusedValRow = new UnsafeRow(valEncoder.schema.fields.length) + private val NO_TTL_ENCODED_VALUE: Long = -1L + // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. def encodeGroupingKey(): UnsafeRow = { @@ -77,6 +79,12 @@ class StateTypesEncoder[GK, V]( keyRow } + /** + * Encodes the provided grouping key into Spark UnsafeRow. 
+ * + * @param groupingKeyBytes serialized grouping key byte array + * @return encoded UnsafeRow + */ def encodeSerializedGroupingKey( groupingKeyBytes: Array[Byte]): UnsafeRow = { val keyRow = keyProjection(InternalRow(groupingKeyBytes)) @@ -92,13 +100,21 @@ class StateTypesEncoder[GK, V]( keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } + /** + * Encode the specified value in Spark UnsafeRow with no ttl. + * The ttl expiration will be set to -1, specifying no TTL. + */ def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes, 0L)) + val valRow = valueProjection(InternalRow(bytes, NO_TTL_ENCODED_VALUE)) valRow } + /** + * Encode the specified value in Spark UnsafeRow + * with provided ttl expiration. + */ def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() @@ -113,9 +129,19 @@ class StateTypesEncoder[GK, V]( value } - def decodeTtlExpirationMs(row: UnsafeRow): Long = { + /** + * Decode the ttl information out of Value row. If the ttl has + * not been set (-1L specifies no user defined value), the API will + * return None. + */ + def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { val expirationMs = row.getLong(1) - expirationMs + + if (expirationMs == NO_TTL_ENCODED_VALUE) { + None + } else { + Some(expirationMs) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala similarity index 69% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index 07fa8c141b15..bfddb38f17bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -33,14 +33,48 @@ object StateTTLSchema { StructType(Array(StructField("__dummy__", NullType))) } -trait StateVariableTTLSupport { +/** + * Represents a State variable which supports TTL. + */ +trait StateVariableWithTTLSupport { + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ def clearIfExpired(groupingKey: Array[Byte]): Unit } +/** + * Represents the underlying state for secondary TTL Index for a user defined + * state variable. + * + * This state allows Spark to query ttl values based on expiration time + * allowing efficient ttl cleanup. + */ trait TTLState { + + /** + * Perform the user state clean yp based on ttl values stored in + * this state. 
NOTE that its not safe to call this operation concurrently + * when the user can also modify the underlying State. Cleanup should be initiated + * after arbitrary state operations are completed by the user. + */ def clearExpiredState(): Unit } +/** + * Manages the ttl information for user state keyed with a single key (grouping key). + */ class SingleKeyTTLState( ttlMode: TTLMode, stateName: String, @@ -54,7 +88,7 @@ class SingleKeyTTLState( private val ttlColumnFamilyName = s"_ttl_$stateName" private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) - private var state: StateVariableTTLSupport = _ + private var state: StateVariableWithTTLSupport = _ // empty row used for values private val EMPTY_ROW = @@ -87,11 +121,14 @@ class SingleKeyTTLState( } private[sql] def setStateVariable( - state: StateVariableTTLSupport): Unit = { + state: StateVariableWithTTLSupport): Unit = { this.state = state } } +/** + * Helper methods for user State TTL. + */ object StateTTL { def calculateExpirationTimeForDuration( ttlMode: TTLMode, @@ -103,7 +140,8 @@ object StateTTL { } else if (ttlMode == TTLMode.EventTimeTTL()) { eventTimeWatermarkMs.get + ttlDuration.toMillis } else { - -1L + throw new IllegalStateException(s"cannot calculate expiration time for" + + s" unknown ttl Mode $ttlMode") } } @@ -117,7 +155,8 @@ object StateTTL { } else if (ttlMode == TTLMode.EventTimeTTL()) { eventTimeWatermarkMs.get > expirationMs } else { - false + throw new IllegalStateException(s"cannot evaluate expiry condition for" + + s" unknown ttl Mode $ttlMode") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index b74d09232421..083e12e98be4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -110,12 +110,6 @@ class StatefulProcessorHandleImpl( private var currState: StatefulProcessorHandleState = CREATED - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } - def setHandleState(newState: StatefulProcessorHandleState): Unit = { currState = newState } @@ -203,6 +197,10 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + /** + * Performs the user state cleanup based on assigned TTl values. Any state + * which is expired will be cleaned up from StateStore. 
+ */ def doTtlCleanup(): Unit = { ttlStates.forEach { s => s.clearExpiredState() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 30de84d550bb..ec65caee899a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data - * @param ttlMode defines the ttl Mode + * @param ttlMode defines the ttl Mode for user state * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -281,7 +281,7 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { - // join ttlBackgroundThread forkjoinpool + // clean up any expired user state processorHandle.doTtlCleanup() store.commit() } else { @@ -299,20 +299,8 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - // TODO(sahnib) add validation for ttlMode - timeoutMode match { - case ProcessingTime => - if (batchTimestampMs.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case EventTime => - if (eventTimeWatermarkForEviction.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case _ => - } + validateTTLMode() + validateTimeoutMode() if (hasInitialState) { val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) @@ -463,6 +451,38 @@ case class TransformWithStateExec( processDataWithPartition(childDataIterator, store, processorHandle) } + + private def validateTimeoutMode(): Unit = { + timeoutMode match { + case ProcessingTime => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + } + + private def validateTTLMode(): Unit = { + ttlMode match { + case ProcessingTimeTTL => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case EventTimeTTL => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case _ => + } + } } // scalastyle:off argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 7cfc873ef3a5..11f4c8e1f461 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -21,8 +21,9 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import 
org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** @@ -32,9 +33,10 @@ import org.apache.spark.sql.streaming.{TTLMode, ValueState} * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value - * @param ttlMode ttl Mode to evict expired values from state + * @param ttlMode - TTL Mode for values stored in this state * @param batchTimestampMs processing timestamp of the current batch. - * @param eventTimeWatermarkMs event time watermark for state eviction + * @param eventTimeWatermarkMs event time watermark for streaming query + * (same as watermark for state eviction) * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( @@ -47,7 +49,7 @@ class ValueStateImpl[S]( eventTimeWatermarkMs: Option[Long]) extends ValueState[S] with Logging - with StateVariableTTLSupport { + with StateVariableWithTTLSupport { private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) @@ -84,11 +86,7 @@ class ValueStateImpl[S]( if (retRow != null) { val resState = stateTypesEncoder.decodeValue(retRow) - val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) - val isExpired = StateTTL.isExpired(ttlMode, - expirationMs, batchTimestampMs, eventTimeWatermarkMs) - - if (!isExpired) { + if (!isExpired(retRow)) { resState } else { null.asInstanceOf[S] @@ -104,8 +102,7 @@ class ValueStateImpl[S]( ttlDuration: Duration = Duration.ZERO): Unit = { if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { - // TODO(sahnib) throw a StateStoreError here - throw new RuntimeException() + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) } var expirationMs: Long = -1 @@ -131,13 +128,17 @@ class ValueStateImpl[S]( val retRow = store.get(encodedGroupingKey, stateName) if (retRow != null) { - val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) - val isExpired = StateTTL.isExpired(ttlMode, - expirationMs, batchTimestampMs, eventTimeWatermarkMs) - - if (!isExpired) { + if (isExpired(retRow)) { store.remove(encodedGroupingKey, stateName) } } } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1c1159f65454..1817104a5c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,10 +342,6 @@ class RocksDB( } } - def listColumnFamilies(): Seq[String] = { - Seq() - } - /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c6db9b5d000c..c8537f2a6a5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,13 +257,6 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } - - /** Return a list of column family names */ - override def listColumnFamilies(): Seq[String] = { - verify(useColumnFamilies, "Column families are not supported in this store") - // turn keys of keyValueEncoderMap to Seq - rocksDB.listColumnFamilies() - } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index c79c244dcf44..dd97aa5b9afc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,11 +122,6 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit - /** - * Return list of column family names. - */ - def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) - /** * Create column family with given name, if absent. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index e86670355cae..ad16063653e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -39,6 +39,13 @@ object StateStoreErrors { ) } + def missingTTLValues(ttlMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for ttlMode=$ttlMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -49,11 +56,6 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } - def ttlNotSupportedWithProvider(stateStoreProvider: String): - StateStoreTTLNotSupportedException = { - new StateStoreTTLNotSupportedException(stateStoreProvider) - } - def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -122,19 +124,17 @@ object StateStoreErrors { StatefulProcessorCannotReInitializeState = { new StatefulProcessorCannotReInitializeState(groupingKey) } + + def cannotProvideTTLDurationForNoTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotAssignTTLInNoTTLMode = { + new StatefulProcessorCannotAssignTTLInNoTTLMode(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = 
Map("stateStoreProvider" -> stateStoreProvider) - ) - -class StateStoreTTLNotSupportedException(stateStoreProvider: String) - extends SparkUnsupportedOperationException( - errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider) - ) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( @@ -204,3 +204,10 @@ class StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin extends SparkUnsupportedOperationException( errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", messageParameters = Map("fieldName" -> fieldName, "index" -> index)) + +class StatefulProcessorCannotAssignTTLInNoTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index a854781c5e2c..189398b0319b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -22,36 +22,53 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{RocksDB, RocksDBConf, RocksDBStateStoreProvider, StateStoreConf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock +case class InputEvent( + key: String, + action: String, + value: Int, + ttl: Duration) + +case class OutputEvent( + key: String, + exists: Boolean, + value: Int) + class ValueStateTTLProcessor - extends StatefulProcessor[String, String, (String, Long)] + extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _countState: ValueState[Long] = _ + @transient private var _valueState: ValueState[Int] = _ override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { - _countState = getHandle.getValueState("countState", Encoders.scalaLong) + _valueState = getHandle.getValueState("valueState", Encoders.scalaInt) } override def handleInputRows( key: String, - inputRows: Iterator[String], + inputRows: Iterator[InputEvent], timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = { + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() - val currValueOption = _countState.getOption() + for (row <- inputRows) { + if (row.action == "get") { + val currState = _valueState.getOption() - var totalLogins: Long = inputRows.size - if (currValueOption.isDefined) { - totalLogins = totalLogins + currValueOption.get + if (currState.isDefined) { + results = OutputEvent(key, exists = true, currState.get) :: results + } else { + results = OutputEvent(key, exists = false, -1) :: results + } + } else if (row.action == "put") { + _valueState.update(row.value, row.ttl) + } } - _countState.update(totalLogins, 
Duration.ofMinutes(1)) - - Iterator.single((key, totalLogins)) + results.iterator } } @@ -60,13 +77,12 @@ class TransformWithStateTTLSuite import testImplicits._ test("validate state is evicted at ttl expiry") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val inputStream = MemoryStream[String] + val inputStream = MemoryStream[InputEvent] val result = inputStream.toDS() - .groupByKey(x => x) + .groupByKey(x => x.key) .transformWithState( new ValueStateTTLProcessor(), TimeoutMode.NoTimeouts(), @@ -75,14 +91,22 @@ class TransformWithStateTTLSuite val clock = new StreamManualClock testStream(result)( StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputStream, "k1"), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state + AddData(inputStream, InputEvent("k1", "get", 1, null)), + // advance clock to trigger processing AdvanceManualClock(1 * 1000), - CheckNewAnswer(("k1", 1L)), + CheckNewAnswer(OutputEvent("k1", exists = true, 1)), // advance clock so that state expires AdvanceManualClock(60 * 1000), - AddData(inputStream, "k1"), + AddData(inputStream, InputEvent("k1", "get", -1, null)), AdvanceManualClock(1 * 1000), - CheckNewAnswer(("k1", 1)) + // validate state does not exist anymore + CheckNewAnswer(OutputEvent("k1", exists = false, -1)) + // ensure this state does not exist any longer in State ) } } From 2d0c81d40f9a1751527220ac4a8d444e01cbe03e Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 16:48:40 -0700 Subject: [PATCH 04/21] Regenerate docs for SQL Error conditions. --- docs/sql-error-conditions.md | 6 ++++++ .../spark/sql/streaming/TransformWithStateTTLSuite.scala | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 85b9e85ac420..501938005d83 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2150,6 +2150,12 @@ The SQL config `` cannot be found. Please verify that the config exists Star (*) is not allowed in a select list when GROUP BY an ordinal position is used. +### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +State store operation=`` on state=`` does not support TTL in NoTTL() mode. 
+ ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index 189398b0319b..f74862b21735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -22,7 +22,7 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{RocksDB, RocksDBConf, RocksDBStateStoreProvider, StateStoreConf} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock From e0404e030873590f34f2b05a92b8c1608a9fd753 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Mon, 25 Mar 2024 18:04:23 -0700 Subject: [PATCH 05/21] Add more testcases for value state ttl. --- .../apache/spark/sql/streaming/TTLMode.java | 6 +- .../spark/sql/streaming/ValueState.scala | 8 +- .../StateVariableWithTTLSupport.scala | 30 +- .../streaming/TransformWithStateExec.scala | 1 + .../execution/streaming/ValueStateImpl.scala | 75 ++- .../TransformWithStateTTLSuite.scala | 461 +++++++++++++++++- 6 files changed, 548 insertions(+), 33 deletions(-) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index 016407db5e10..b18a4b136828 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -33,19 +33,19 @@ public class TTLMode { * Specifies that there is no TTL for the user state. User state would not * be cleaned up by Spark automatically. */ - public static TTLMode NoTTL() { + public static final TTLMode NoTTL() { return NoTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in processing time. */ - public static TTLMode ProcessingTimeTTL() { + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in event time. */ - public static TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 36ee12fa83e1..11e69dd08bf1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -43,8 +43,12 @@ private[sql] trait ValueState[S] extends Serializable { /** Get the state if it exists as an option and None otherwise */ def getOption(): Option[S] - /** Update the value of the state. */ - // TODO(sahnib) confirm if this should be scala or Java type of Duration + /** + * Update the value of the state. 
+ * @param newState the new value + * @param ttlDuration set the ttl to current batch processing time (for processing time TTL mode) + * or current watermark (for event time ttl mode) plus ttlDuration + */ def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Remove this state. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index bfddb38f17bd..c46b03a9dcae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -33,6 +33,15 @@ object StateTTLSchema { StructType(Array(StructField("__dummy__", NullType))) } +/** + * Encapsulates the ttl row information stored in [[SingleKeyTTLState]]. + * @param groupingKey grouping key for which ttl is set + * @param expirationMs expiration time for the grouping key + */ +case class SingleKeyTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + /** * Represents a State variable which supports TTL. */ @@ -114,7 +123,6 @@ class SingleKeyTTLState( val groupingKey = kv.key.getBinary(1) state.clearIfExpired(groupingKey) - // TODO(sahnib) ; validate its safe to update inside iterator store.remove(kv.key, ttlColumnFamilyName) } } @@ -124,6 +132,22 @@ class SingleKeyTTLState( state: StateVariableWithTTLSupport): Unit = { this.state = state } + + private[sql] def iterator(): Iterator[SingleKeyTTLRow] = { + val ttlIterator = store.iterator(ttlColumnFamilyName) + + new Iterator[SingleKeyTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyTTLRow = { + val kv = ttlIterator.next() + SingleKeyTTLRow( + expirationMs = kv.key.getLong(0), + groupingKey = kv.key.getBinary(1) + ) + } + } + } } /** @@ -151,9 +175,9 @@ object StateTTL { batchTimestampMs: Option[Long], eventTimeWatermarkMs: Option[Long]): Boolean = { if (ttlMode == TTLMode.ProcessingTimeTTL()) { - batchTimestampMs.get > expirationMs + batchTimestampMs.get >= expirationMs } else if (ttlMode == TTLMode.EventTimeTTL()) { - eventTimeWatermarkMs.get > expirationMs + eventTimeWatermarkMs.get >= expirationMs } else { throw new IllegalStateException(s"cannot evaluate expiry condition for" + s" unknown ttl Mode $ttlMode") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index ec65caee899a..929fd22c268e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -469,6 +469,7 @@ case class TransformWithStateExec( } private def validateTTLMode(): Unit = { + logWarning(s"Validating ttl Mode - $ttlMode $eventTimeWatermarkForEviction") ttlMode match { case ProcessingTimeTTL => if (batchTimestampMs.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 11f4c8e1f461..8d9eaeb39f91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -106,7 +106,7 @@ 
class ValueStateImpl[S]( } var expirationMs: Long = -1 - if (ttlDuration != Duration.ZERO) { + if (ttlDuration != null && ttlDuration != Duration.ZERO) { expirationMs = StateTTL.calculateExpirationTimeForDuration( ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) } @@ -141,4 +141,77 @@ class ValueStateImpl[S]( isExpired.isDefined && isExpired.get } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + if (ttlState.isEmpty) { + Iterator.empty + } + + val ttlIterator = ttlState.get.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index f74862b21735..a1acf3689d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.streaming -import java.time.Duration +import java.sql.Timestamp +import java.time.{Duration, Instant} +import java.time.temporal.ChronoUnit import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -30,21 +32,59 @@ case class InputEvent( key: String, action: String, value: Int, - ttl: Duration) + ttl: Duration, + eventTime: Timestamp = null) case class OutputEvent( key: String, - exists: Boolean, - value: Int) + value: Int, + isTTLValue: Boolean, + ttlValue: Long) + +object TTLInputProcessFunction { + def 
processRow( + row: InputEvent, + valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = valueState.getWithoutEnforcingTTL() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpiration = valueState.getTTLValue() + if (ttlExpiration.isDefined) { + results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results + } + } else if (row.action == "put") { + valueState.update(row.value, row.ttl) + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = valueState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } +} class ValueStateTTLProcessor extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _valueState: ValueState[Int] = _ + @transient private var _valueState: ValueStateImpl[Int] = _ override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { - _valueState = getHandle.getValueState("valueState", Encoders.scalaInt) + _valueState = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] } override def handleInputRows( @@ -55,16 +95,9 @@ class ValueStateTTLProcessor var results = List[OutputEvent]() for (row <- inputRows) { - if (row.action == "get") { - val currState = _valueState.getOption() - - if (currState.isDefined) { - results = OutputEvent(key, exists = true, currState.get) :: results - } else { - results = OutputEvent(key, exists = false, -1) :: results - } - } else if (row.action == "put") { - _valueState.update(row.value, row.ttl) + val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + resultIter.foreach { r => + results = r :: results } } @@ -72,11 +105,51 @@ class ValueStateTTLProcessor } } +case class MultipleValueStatesTTLProcessor( + ttlKey: String, + noTtlKey: String) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueStateWithTTL: ValueStateImpl[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _valueStateWithTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + _valueStateWithoutTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val state = if (key == ttlKey) { + _valueStateWithTTL + } else { + _valueStateWithoutTTL + } + + for (row <- inputRows) { + val resultIterator = TTLInputProcessFunction.processRow(row, state) + resultIterator.foreach { r => + results = r :: results + } + } + results.iterator + } +} + class TransformWithStateTTLSuite extends StreamTest { import testImplicits._ - test("validate state is evicted at ttl expiry") { + test("validate state is evicted at ttl expiry - 
processing time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -95,18 +168,358 @@ class TransformWithStateTTLSuite // advance clock to trigger processing AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // get this state - AddData(inputStream, InputEvent("k1", "get", 1, null)), - // advance clock to trigger processing + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), AdvanceManualClock(1 * 1000), - CheckNewAnswer(OutputEvent("k1", exists = true, 1)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), // advance clock so that state expires AdvanceManualClock(60 * 1000), AddData(inputStream, InputEvent("k1", "get", -1, null)), AdvanceManualClock(1 * 1000), - // validate state does not exist anymore - CheckNewAnswer(OutputEvent("k1", exists = false, -1)) + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update expiration time + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is updated in the state + 
AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)), + // validate ttl state has both ttl values present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000), + OutputEvent("k1", -1, isTTLValue = true, 95000) + ), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)) + ) + } + } + + test("validate ttl removal keeps value in state - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update state to remove ttl + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, null)), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", 
"get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate multiple value states - with and without ttl - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val ttlKey = "k1" + val noTtlKey = "k2" + + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get both state values, and make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, bot noTtlKey is still present + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate state is evicted at ttl expiry - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired 
value + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // increment event time so that key k1 expires + AddData(inputStream, InputEvent("k2", "put", 1, ttlDuration, eventTime3)), + CheckNewAnswer(), + // validate that k1 has expired + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(), // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null, eventTime3)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // remove tll expiration for key k1, and move watermark past previous ttl value + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime3)), + CheckNewAnswer(), + // validate that the key still exists + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // ensure this ttl expiration time has been removed from state + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable() + ) + } + } + + test("validate ttl removal keeps value in state - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + 
SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // update state and remove ttl + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime2)), + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // eventTime has been advanced to eventTim3 which is after older expiration value + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + CheckNewAnswer() ) } } From d87472a136c8a3def13d89ff3c4130dbb4d324bd Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Mon, 25 Mar 2024 21:23:29 -0700 Subject: [PATCH 06/21] Fix indentation. --- .../src/main/java/org/apache/spark/sql/streaming/TTLMode.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index b18a4b136828..06af92dc1321 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -40,9 +40,7 @@ public static final TTLMode NoTTL() { /** * Specifies that all ttl durations for user state are in processing time. */ - public static final TTLMode ProcessingTimeTTL() { - return ProcessingTimeTTL$.MODULE$; - } + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in event time. 
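(Editor's illustrative aside, not part of the patch series: the tests above exercise the TTL API end to end; the sketch below condenses the same processing-time usage into a single hypothetical processor, `VisitCountProcessor`, using the `ValueState.update(value, ttlDuration)` and `transformWithState(..., ttlMode)` signatures introduced so far. The expiration timestamp is the current batch's processing time plus the duration, and the internal TTL column family removes the entry once that timestamp passes; in `EventTimeTTL` mode the duration is added to the event-time watermark instead. A later patch in the series also threads `ttlMode` through `StatefulProcessor.init`.)

```scala
// Hypothetical usage sketch, not part of this patch series.
import java.time.Duration
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

class VisitCountProcessor extends StatefulProcessor[String, String, (String, Int)] {
  @transient private var visits: ValueState[Int] = _

  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    visits = getHandle.getValueState[Int]("visits", Encoders.scalaInt)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = {
    val updated = visits.getOption().getOrElse(0) + inputRows.size
    // State written here expires one minute of processing time after this batch;
    // expired entries are cleaned up at the end of a later micro-batch.
    visits.update(updated, Duration.ofMinutes(1))
    Iterator.single((key, updated))
  }
}

// val counted = events                       // events: Dataset[String], hypothetical
//   .groupByKey(identity)
//   .transformWithState(new VisitCountProcessor(),
//     TimeoutMode.NoTimeouts(), TTLMode.ProcessingTimeTTL())
```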
From dc916fd634f8ef1fb9328ab786c45a26d9cd9e9b Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 08:49:31 -0700 Subject: [PATCH 07/21] Add a placeholder TODO to use range scan once available. --- .../sql/execution/streaming/StateVariableWithTTLSupport.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index c46b03a9dcae..286b926577fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -103,6 +103,7 @@ class SingleKeyTTLState( private val EMPTY_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + // TODO: use range scan once Range Scan PR is merged for StateStore store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) From 93f4807debe9e7c8144855720d2e99a387e59343 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 11:58:28 -0700 Subject: [PATCH 08/21] Remove unnecessary log in TransformWithStateExec. --- .../spark/sql/execution/streaming/TransformWithStateExec.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 929fd22c268e..ec65caee899a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -469,7 +469,6 @@ case class TransformWithStateExec( } private def validateTTLMode(): Unit = { - logWarning(s"Validating ttl Mode - $ttlMode $eventTimeWatermarkForEviction") ttlMode match { case ProcessingTimeTTL => if (batchTimestampMs.isEmpty) { From 50e75233966dc5d431b18169ad3e08c63a10068c Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 13:16:13 -0700 Subject: [PATCH 09/21] Rebase with latest master --- .../org/apache/spark/sql/streaming/TransformWithStateSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 539211393214..3fdbc74ca955 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -681,6 +681,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) } From 4d58be07cf5941f72614df107d1c9da53c64a7ea Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 14:22:34 -0700 Subject: [PATCH 10/21] Suppress method checkstyle for TTLMode. 
--- dev/checkstyle-suppressions.xml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 7b20dfb6bce5..94dfe20af56e 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -60,6 +60,8 @@ files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/> + Date: Tue, 26 Mar 2024 19:54:15 -0700 Subject: [PATCH 11/21] Modify Spark Connect tws API to include ttlMode. --- .../scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 5 ++++- .../scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e5dd67ccf787..5c0f4a7e8c81 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -830,12 +830,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. Defaults to APPEND mode. */ def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f9758e4cbf98..10586cd65963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -656,6 +656,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the * operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. 
* */ From b9278c7eaa0794021f07f0b17d309f6439f2f38b Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Wed, 27 Mar 2024 10:45:08 -0700 Subject: [PATCH 12/21] Rename ttl changes scala file to TTLState.scala --- .../{StateVariableWithTTLSupport.scala => TTLState.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{StateVariableWithTTLSupport.scala => TTLState.scala} (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala From dfb056079092591e8d07969465c1111c4e9e4098 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Thu, 28 Mar 2024 11:28:28 -0700 Subject: [PATCH 13/21] Do not write ttlExpiration in valueRow if ttlMode is set to NoTTL(). --- .../sql/streaming/StatefulProcessor.scala | 3 +- .../execution/streaming/ListStateImpl.scala | 2 +- .../streaming/StateTypesEncoderUtils.scala | 43 ++-- .../StatefulProcessorHandleImpl.scala | 19 +- .../streaming/TransformWithStateExec.scala | 4 +- .../execution/streaming/ValueStateImpl.scala | 136 +---------- .../streaming/ValueStateImplWithTTL.scala | 215 ++++++++++++++++++ .../TransformWithListStateSuite.scala | 6 +- .../TransformWithMapStateSuite.scala | 3 +- .../streaming/TransformWithStateSuite.scala | 17 +- .../TransformWithStateTTLSuite.scala | 30 ++- 11 files changed, 295 insertions(+), 183 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 1a61972f0ed0..fdec703412a8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -44,7 +44,8 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { */ def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 662bef5716ea..56c9d2664d9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index f97b91ccc960..c78d368f28ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer @@ -25,17 +26,13 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.types.{BinaryType, LongType, StructType} -object StateKeyValueRowSchema { +object TransformWithStateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) val VALUE_ROW_SCHEMA: StructType = new StructType() .add("value", BinaryType) + val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType() + .add("value", BinaryType) .add("ttlExpirationMs", LongType) - - def encodeGroupingKeyBytes(keyBytes: Array[Byte]): UnsafeRow = { - val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) - val keyRow = keyProjection(InternalRow(keyBytes)) - keyRow - } } /** @@ -57,12 +54,17 @@ object StateKeyValueRowSchema { class StateTypesEncoder[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String) { - import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._ + stateName: String, + hasTtl: Boolean) extends Logging { + import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._ /** Variables reused for conversions between byte array and UnsafeRow */ private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) - private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA) + private val valueProjection = if (hasTtl) { + UnsafeProjection.create(VALUE_ROW_SCHEMA_WITH_TTL) + } else { + UnsafeProjection.create(VALUE_ROW_SCHEMA) + } /** Variables reused for value conversions between spark sql and object */ private val valExpressionEnc = encoderFor(valEncoder) @@ -70,8 +72,6 @@ class StateTypesEncoder[GK, V]( private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer() private val reusedValRow = new UnsafeRow(valEncoder.schema.fields.length) - private val NO_TTL_ENCODED_VALUE: Long = -1L - // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. 
def encodeGroupingKey(): UnsafeRow = { @@ -107,8 +107,7 @@ class StateTypesEncoder[GK, V]( def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes, NO_TTL_ENCODED_VALUE)) - valRow + valueProjection(InternalRow(bytes)) } /** @@ -118,8 +117,7 @@ class StateTypesEncoder[GK, V]( def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valueRow = valueProjection(InternalRow(bytes, expirationMs)) - valueRow + valueProjection(InternalRow(bytes, expirationMs)) } def decodeValue(row: UnsafeRow): V = { @@ -136,8 +134,7 @@ class StateTypesEncoder[GK, V]( */ def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { val expirationMs = row.getLong(1) - - if (expirationMs == NO_TTL_ENCODED_VALUE) { + if (expirationMs == -1) { None } else { Some(expirationMs) @@ -149,8 +146,9 @@ object StateTypesEncoder { def apply[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String): StateTypesEncoder[GK, V] = { - new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) + stateName: String, + hasTtl: Boolean = false): StateTypesEncoder[GK, V] = { + new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) } } @@ -159,8 +157,9 @@ class CompositeKeyStateEncoder[GK, K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V], schemaForCompositeKeyRow: StructType, - stateName: String) - extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) { + stateName: String, + hasTtl: Boolean = false) + extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) { private val compositeKeyProjection = UnsafeProjection.create(schemaForCompositeKeyRow) private val reusedKeyRow = new UnsafeRow(userKeyEnc.schema.fields.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 083e12e98be4..238c94cddf71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -89,7 +89,6 @@ class StatefulProcessorHandleImpl( private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - logInfo(s"Created StatefulProcessorHandle") private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) @@ -121,16 +120,18 @@ class StatefulProcessorHandleImpl( valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, - ttlMode, batchTimestampMs, eventTimeWatermarkMs) - val ttlState = resultState.ttlState + if (ttlMode == TTLMode.NoTTL()) { + new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + } else { + val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) - ttlState.foreach { s => - ttlStates.add(s) - s.setStateVariable(resultState) - } + val ttlState = valueStateWithTTL.ttlState + ttlState.setStateVariable(valueStateWithTTL) + ttlStates.add(ttlState) - resultState + 
valueStateWithTTL + } } override def getQueryInfo(): QueryInfo = currQueryInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index ec65caee899a..183f63a42cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -419,7 +419,7 @@ case class TransformWithStateExec( isStreaming, batchTimestampMs, eventTimeWatermarkForEviction) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 8d9eaeb39f91..9bd61fbe679f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -21,10 +21,9 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} -import org.apache.spark.sql.streaming.{TTLMode, ValueState} +import org.apache.spark.sql.streaming.ValueState /** * Class that provides a concrete implementation for a single value state associated with state @@ -33,23 +32,15 @@ import org.apache.spark.sql.streaming.{TTLMode, ValueState} * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value - * @param ttlMode - TTL Mode for values stored in this state - * @param batchTimestampMs processing timestamp of the current batch. 
- * @param eventTimeWatermarkMs event time watermark for streaming query - * (same as watermark for state eviction) * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S], - ttlMode: TTLMode, - batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long]) + valEncoder: Encoder[S]) extends ValueState[S] - with Logging - with StateVariableWithTTLSupport { + with Logging { private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) @@ -60,12 +51,6 @@ class ValueStateImpl[S]( private def initialize(): Unit = { store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) - - if (ttlMode != TTLMode.NoTTL()) { - val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, - batchTimestampMs, eventTimeWatermarkMs) - ttlState = Some(_ttlState) - } } /** Function to check if state exists. Returns true if present and false otherwise */ @@ -84,13 +69,7 @@ class ValueStateImpl[S]( val retRow = store.get(encodedGroupingKey, stateName) if (retRow != null) { - val resState = stateTypesEncoder.decodeValue(retRow) - - if (!isExpired(retRow)) { - resState - } else { - null.asInstanceOf[S] - } + stateTypesEncoder.decodeValue(retRow) } else { null.asInstanceOf[S] } @@ -101,117 +80,18 @@ class ValueStateImpl[S]( newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { - if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { + if (ttlDuration != Duration.ZERO) { throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) } - var expirationMs: Long = -1 - if (ttlDuration != null && ttlDuration != Duration.ZERO) { - expirationMs = StateTTL.calculateExpirationTimeForDuration( - ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) - } - + val encodedValue = stateTypesEncoder.encodeValue(newState) val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), - stateTypesEncoder.encodeValue(newState, expirationMs), stateName) - - ttlState.foreach(_.upsertTTLForStateKey(expirationMs, serializedGroupingKey)) + encodedValue, stateName) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) } - - def clearIfExpired(groupingKey: Array[Byte]): Unit = { - val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) - val retRow = store.get(encodedGroupingKey, stateName) - - if (retRow != null) { - if (isExpired(retRow)) { - store.remove(encodedGroupingKey, stateName) - } - } - } - - private def isExpired(valueRow: UnsafeRow): Boolean = { - val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) - val isExpired = expirationMs.map( - StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) - - isExpired.isDefined && isExpired.get - } - - /* - * Internal methods to probe state for testing. The below methods exist for unit tests - * to read the state ttl values, and ensure that values are persisted correctly in - * the underlying state store. - */ - - /** - * Retrieves the value from State even if its expired. This method is used - * in tests to read the state store value, and ensure if its cleaned up at the - * end of the micro-batch. 
- */ - private[sql] def getWithoutEnforcingTTL(): Option[S] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() - val retRow = store.get(encodedGroupingKey, stateName) - - if (retRow != null) { - val resState = stateTypesEncoder.decodeValue(retRow) - Some(resState) - } else { - None - } - } - - /** - * Read the ttl value associated with the grouping key. - */ - private[sql] def getTTLValue(): Option[Long] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() - val retRow = store.get(encodedGroupingKey, stateName) - - if (retRow != null) { - stateTypesEncoder.decodeTtlExpirationMs(retRow) - } else { - None - } - } - - /** - * Get all ttl values stored in ttl state for current implicit - * grouping key. - */ - private[sql] def getValuesInTTLState(): Iterator[Long] = { - if (ttlState.isEmpty) { - Iterator.empty - } - - val ttlIterator = ttlState.get.iterator() - val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() - var nextValue: Option[Long] = None - - new Iterator[Long] { - override def hasNext: Boolean = { - while (nextValue.isEmpty && ttlIterator.hasNext) { - val nextTtlValue = ttlIterator.next() - val groupingKey = nextTtlValue.groupingKey - - if (groupingKey sameElements implicitGroupingKey) { - nextValue = Some(nextTtlValue.expirationMs) - } - } - - nextValue.isDefined - } - - override def next(): Long = { - val result = nextValue.get - nextValue = None - - result - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala new file mode 100644 index 000000000000..55d2a6581df8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.{TTLMode, ValueState} + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables (with ttl expiration support) used in the streaming transformWithState operator. 
+ * + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyExprEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @param ttlMode - TTL Mode for values stored in this state + * @param batchTimestampMs - processing timestamp of the current batch. + * @param eventTimeWatermarkMs - event time watermark for streaming query + * (same as watermark for state eviction) + * @tparam S - data type of object that will be stored + */ +class ValueStateImplWithTTL[S]( + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ValueState[S] + with Logging + with StateVariableWithTTLSupport { + + private val keySerializer = keyExprEnc.createSerializer() + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, + stateName, hasTtl = true) + private[sql] var ttlState: SingleKeyTTLState = _ + + initialize() + + private def initialize(): Unit = { + assert(ttlMode != TTLMode.NoTTL()) + + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + + ttlState = new SingleKeyTTLState(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs) + } + + /** Function to check if state exists. Returns true if present and false otherwise */ + override def exists(): Boolean = { + get() != null + } + + /** Function to return Option of value if exists and None otherwise */ + override def getOption(): Option[S] = { + Option(get()) + } + + /** Function to return associated value with key if exists and null otherwise */ + override def get(): S = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + + if (!isExpired(retRow)) { + resState + } else { + null.asInstanceOf[S] + } + } else { + null.asInstanceOf[S] + } + } + + /** Function to update and overwrite state associated with given key */ + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + val expirationMs = + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } else { + -1 + } + + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + + ttlState.upsertTTLForStateKey(expirationMs, serializedGroupingKey) + } + + /** Function to remove state for given key */ + override def clear(): Unit = { + store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + if (isExpired(retRow)) { + store.remove(encodedGroupingKey, stateName) + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, 
eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + val ttlIterator = ttlState.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 51cc9ff87890..5ccc14ab8a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -32,7 +32,8 @@ class TestListStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) } @@ -89,7 +90,8 @@ class ToggleSaveAndEmitProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 3c1280137af9..d32b9687d95f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -32,7 +32,8 @@ class TestMapStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _mapState = getHandle.getMapState("sessionState", 
Encoders.STRING, Encoders.STRING) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 3fdbc74ca955..9b193276aa39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -40,7 +40,8 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) } @@ -103,8 +104,9 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - super.init(outputMode, timeoutMode) + timeoutMode: TimeoutMode, + ttlMode: TTLMode) : Unit = { + super.init(outputMode, timeoutMode, ttlMode) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } @@ -186,7 +188,8 @@ class MaxEventTimeStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState", Encoders.scalaLong) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) @@ -227,7 +230,8 @@ class RunningCountMostRecentStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } @@ -257,7 +261,8 @@ class MostRecentStatefulProcessorWithDeletion override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { getHandle.deleteIfExists("countState") _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index a1acf3689d0a..75d98ce0933b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl} +import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImplWithTTL} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -44,7 +44,7 @@ case class OutputEvent( object TTLInputProcessFunction { def processRow( row: InputEvent, - valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = { + valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = { var results = List[OutputEvent]() val key = row.key if (row.action == "get") { @@ -79,12 +79,17 @@ class ValueStateTTLProcessor extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - 
@transient private var _valueState: ValueStateImpl[Int] = _ + @transient private var _valueState: ValueStateImplWithTTL[Int] = _ + private var _ttlMode: TTLMode = _ - override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valueState = getHandle .getValueState("valueState", Encoders.scalaInt) - .asInstanceOf[ValueStateImpl[Int]] + .asInstanceOf[ValueStateImplWithTTL[Int]] + _ttlMode = ttlMode } override def handleInputRows( @@ -111,16 +116,19 @@ case class MultipleValueStatesTTLProcessor( extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _valueStateWithTTL: ValueStateImpl[Int] = _ - @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImplWithTTL[Int] = _ - override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valueStateWithTTL = getHandle .getValueState("valueState", Encoders.scalaInt) - .asInstanceOf[ValueStateImpl[Int]] + .asInstanceOf[ValueStateImplWithTTL[Int]] _valueStateWithoutTTL = getHandle .getValueState("valueState", Encoders.scalaInt) - .asInstanceOf[ValueStateImpl[Int]] + .asInstanceOf[ValueStateImplWithTTL[Int]] } override def handleInputRows( @@ -317,7 +325,7 @@ class TransformWithStateTTLSuite } } - test("validate multiple value states - with and without ttl - processing time ttl") { + test("validate multiple value states - processing time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { From 5e8485cf7f4f97ab5c41304fd129461fb9dfbcac Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Thu, 28 Mar 2024 11:33:55 -0700 Subject: [PATCH 14/21] Incorporated comments. --- .../execution/streaming/StatefulProcessorHandleImpl.scala | 4 ++++ .../apache/spark/sql/execution/streaming/TTLState.scala | 7 ++++--- .../spark/sql/execution/streaming/ValueStateImpl.scala | 2 +- .../sql/execution/streaming/ValueStateImplWithTTL.scala | 4 ++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 238c94cddf71..1c70e676bfd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -86,6 +86,10 @@ class StatefulProcessorHandleImpl( extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + /** + * Stores all the active ttl states, and is used to cleanup expired values + * in [[doTtlCleanup()]] function. 
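+ * State variables created with a TTL register their TTL state here when they
+ * are instantiated through this handle.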
+ */ private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 286b926577fb..127cb956628a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -34,7 +34,8 @@ object StateTTLSchema { } /** - * Encapsulates the ttl row information stored in [[SingleKeyTTLState]]. + * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]]. + * * @param groupingKey grouping key for which ttl is set * @param expirationMs expiration time for the grouping key */ @@ -73,7 +74,7 @@ trait StateVariableWithTTLSupport { trait TTLState { /** - * Perform the user state clean yp based on ttl values stored in + * Perform the user state clean up based on ttl values stored in * this state. NOTE that its not safe to call this operation concurrently * when the user can also modify the underlying State. Cleanup should be initiated * after arbitrary state operations are completed by the user. @@ -84,7 +85,7 @@ trait TTLState { /** * Manages the ttl information for user state keyed with a single key (grouping key). */ -class SingleKeyTTLState( +class SingleKeyTTLStateImpl( ttlMode: TTLMode, stateName: String, store: StateStore, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 9bd61fbe679f..964098243b03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -44,7 +44,7 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - private[sql] var ttlState: Option[SingleKeyTTLState] = None + private[sql] var ttlState: Option[SingleKeyTTLStateImpl] = None initialize() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 55d2a6581df8..17145c1cfa02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -55,7 +55,7 @@ class ValueStateImplWithTTL[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName, hasTtl = true) - private[sql] var ttlState: SingleKeyTTLState = _ + private[sql] var ttlState: SingleKeyTTLStateImpl = _ initialize() @@ -65,7 +65,7 @@ class ValueStateImplWithTTL[S]( store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) - ttlState = new SingleKeyTTLState(ttlMode, stateName, store, + ttlState = new SingleKeyTTLStateImpl(ttlMode, stateName, store, batchTimestampMs, eventTimeWatermarkMs) } From 21cc40aa8dcbd497320847ed7f8717696520069b Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Thu, 28 Mar 2024 12:24:07 -0700 Subject: [PATCH 15/21] Incorporate TTLMode for 
initialState changes. --- .../execution/streaming/TransformWithStateExec.scala | 2 +- .../TransformWithStateInitialStateSuite.scala | 11 ++++++++--- .../spark/sql/streaming/TransformWithStateSuite.scala | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 183f63a42cc3..2390e19384f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -433,7 +433,7 @@ case class TransformWithStateExec( keyEncoder, ttlMode, timeoutMode, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) // Check if is first batch diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 9f2e2c2d9f02..031a515d3b79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -34,7 +34,10 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] @transient var _listState: ListState[Double] = _ @transient var _mapState: MapState[Double, Int] = _ - override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble) _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble) _mapState = getHandle.getMapState[Double, Int]( @@ -154,7 +157,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 1))) .toDS().groupByKey(x => x.key).mapValues(x => x) val query = kvDataSet.transformWithState(new InitialStateInMemoryTestClass(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) testStream(query, OutputMode.Update())( // non-exist key test @@ -232,7 +235,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val query = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf ) testStream(query, OutputMode.Update())( AddData(inputData, InitInputRow("init_1", "add", 50.0)), @@ -252,6 +255,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val result = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), createInitialDfForTest) @@ -270,6 +274,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val query = inputData.toDS() 
.groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 9b193276aa39..3f21c50abae4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -798,7 +798,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { val result = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf ) testStream(result, OutputMode.Update())( AddData(inputData, InitInputRow("a", "add", -1.0)), From a25cbc38fe6b5f8934e83c1625e5782ed3998d12 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Thu, 28 Mar 2024 16:52:56 -0700 Subject: [PATCH 16/21] Modify eventTime ttl to use absolute time instead of a ttlDuration. --- .../main/resources/error/error-classes.json | 7 +++ docs/sql-error-conditions.md | 7 +++ .../spark/sql/streaming/ValueState.scala | 13 +++- .../execution/streaming/ValueStateImpl.scala | 15 +++++ .../streaming/ValueStateImplWithTTL.scala | 20 +++++- .../streaming/state/StateStoreErrors.scala | 12 ++++ .../TransformWithStateTTLSuite.scala | 61 +++++++++++-------- 7 files changed, 106 insertions(+), 29 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 9d2524751317..065385855431 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3565,6 +3565,13 @@ ], "sqlState" : "42802" }, + "STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE" : { + "message" : [ + "TTL duration is not allowed for event time ttl expiration on State store operation= on state=.", + "Use absolute expiration time instead." + ], + "sqlState" : "42802" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 501938005d83..25f068a3d3a2 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2174,6 +2174,13 @@ Failed to perform stateful processor operation=`` with invalid ti Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=``. +### STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +TTL duration is not allowed for event time ttl expiration on State store operation=`` on state=``. +Use absolute expiration time instead. 
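The two `update` overloads introduced in this patch map onto the two TTL modes. A minimal
sketch of how a processor might set TTL in each mode, inside `handleInputRows` of a
`StatefulProcessor` (the `countState` value state, the `count` variable and the
`expirationTimestamp` column are hypothetical names, not part of the patch):

    import java.time.Duration

    // Processing-time TTL mode: pass a relative duration; the expiration time is
    // computed from the processing timestamp of the current batch.
    countState.update(count, Duration.ofMinutes(1))

    // Event-time TTL mode: pass an absolute expiration time in epoch milliseconds.
    // Passing a non-zero Duration in this mode raises
    // STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE.
    countState.update(count, expirationTimestamp.getTime)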
+ ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 11e69dd08bf1..95fc61980fff 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -46,11 +46,20 @@ private[sql] trait ValueState[S] extends Serializable { /** * Update the value of the state. * @param newState the new value - * @param ttlDuration set the ttl to current batch processing time (for processing time TTL mode) - * or current watermark (for event time ttl mode) plus ttlDuration + * @param ttlDuration set the ttl to current batch processing time + * (for processing time TTL mode) plus ttlDuration */ def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit + + /** + * Update the value of the state. + * + * @param newState the new value + * @param expirationMs set the ttl to expirationMs (processingTime or eventTime) + */ + def update(newState: S, expirationMs: Long): Unit + /** Remove this state. */ def clear(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 964098243b03..fa250d659137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -90,6 +90,21 @@ class ValueStateImpl[S]( encodedValue, stateName) } + /** Function to update and overwrite state associated with given key */ + override def update( + newState: S, + expirationMs: Long): Unit = { + + if (expirationMs != -1) { + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) + } + + val encodedValue = stateTypesEncoder.encodeValue(newState) + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + } + /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 17145c1cfa02..77e8e72a70ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** @@ -102,6 +102,11 @@ class ValueStateImplWithTTL[S]( newState: S, ttlDuration: Duration = Duration.ZERO): Unit 
= { + if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) { + throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode( + "update", stateName) + } + val expirationMs = if (ttlDuration != null && ttlDuration != Duration.ZERO) { StateTTL.calculateExpirationTimeForDuration( @@ -119,6 +124,19 @@ class ValueStateImplWithTTL[S]( ttlState.upsertTTLForStateKey(expirationMs, serializedGroupingKey) } + override def update( + newState: S, + expirationMs: Long): Unit = { + + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + + ttlState.upsertTTLForStateKey(expirationMs, serializedGroupingKey) + } + /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index ad16063653e7..6ee7382d7bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -129,6 +129,11 @@ object StateStoreErrors { stateName: String): StatefulProcessorCannotAssignTTLInNoTTLMode = { new StatefulProcessorCannotAssignTTLInNoTTLMode(operationType, stateName) } + + def cannotProvideTTLDurationForEventTimeTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode = { + new StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) @@ -211,3 +216,10 @@ class StatefulProcessorCannotAssignTTLInNoTTLMode( extends SparkUnsupportedOperationException( errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) + +class StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index 75d98ce0933b..112c809d97f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.streaming import java.sql.Timestamp -import java.time.{Duration, Instant} -import java.time.temporal.ChronoUnit +import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders @@ -33,7 +32,8 @@ case class InputEvent( action: String, value: Int, ttl: Duration, - eventTime: Timestamp = null) + eventTime: Timestamp = null, + eventTimeTtl: Timestamp = null) case class OutputEvent( key: String, @@ -43,6 +43,7 @@ case class OutputEvent( object TTLInputProcessFunction { def processRow( + ttlMode: TTLMode, 
row: InputEvent, valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = { var results = List[OutputEvent]() @@ -63,7 +64,13 @@ object TTLInputProcessFunction { results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results } } else if (row.action == "put") { - valueState.update(row.value, row.ttl) + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + valueState.update(row.value, row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + valueState.update(row.value) + } else { + valueState.update(row.value, row.ttl) + } } else if (row.action == "get_values_in_ttl_state") { val ttlValues = valueState.getValuesInTTLState() ttlValues.foreach { v => @@ -80,7 +87,7 @@ class ValueStateTTLProcessor with Logging { @transient private var _valueState: ValueStateImplWithTTL[Int] = _ - private var _ttlMode: TTLMode = _ + @transient private var _ttlMode: TTLMode = _ override def init( outputMode: OutputMode, @@ -100,7 +107,7 @@ class ValueStateTTLProcessor var results = List[OutputEvent]() for (row <- inputRows) { - val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + val resultIter = TTLInputProcessFunction.processRow(_ttlMode, row, _valueState) resultIter.foreach { r => results = r :: results } @@ -118,6 +125,7 @@ case class MultipleValueStatesTTLProcessor( @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ @transient private var _valueStateWithoutTTL: ValueStateImplWithTTL[Int] = _ + @transient private var _ttlMode: TTLMode = _ override def init( outputMode: OutputMode, @@ -129,6 +137,7 @@ case class MultipleValueStatesTTLProcessor( _valueStateWithoutTTL = getHandle .getValueState("valueState", Encoders.scalaInt) .asInstanceOf[ValueStateImplWithTTL[Int]] + _ttlMode = ttlMode } override def handleInputRows( @@ -144,7 +153,7 @@ case class MultipleValueStatesTTLProcessor( } for (row <- inputRows) { - val resultIterator = TTLInputProcessFunction.processRow(row, state) + val resultIterator = TTLInputProcessFunction.processRow(_ttlMode, row, state) resultIterator.foreach { r => results = r :: results } @@ -394,15 +403,15 @@ class TransformWithStateTTLSuite TimeoutMode.NoTimeouts(), TTLMode.EventTimeTTL()) - val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) - val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) - val ttlDuration = Duration.ofMinutes(1) - val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli - val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00") + val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00") + val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00") + val ttlExpirationMs = ttlExpiration.getTime + val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00") testStream(result)( AddData(inputStream, - InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)), CheckNewAnswer(), // get this state, and make sure we get unexpired value AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), @@ -416,7 +425,7 @@ class TransformWithStateTTLSuite ProcessAllAvailable(), CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), // increment event time so that key k1 expires - AddData(inputStream, InputEvent("k2", "put", 1, ttlDuration, eventTime3)), + AddData(inputStream, InputEvent("k2", "put", 1, null, eventTime3)), CheckNewAnswer(), // validate that 
k1 has expired AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), @@ -443,15 +452,15 @@ class TransformWithStateTTLSuite TimeoutMode.NoTimeouts(), TTLMode.EventTimeTTL()) - val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) - val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) - val ttlDuration = Duration.ofMinutes(1) - val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli - val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00") + val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00") + val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00") + val ttlExpirationMs = ttlExpiration.getTime + val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00") testStream(result)( AddData(inputStream, - InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)), CheckNewAnswer(), // get this state, and make sure we get unexpired value AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), @@ -493,14 +502,14 @@ class TransformWithStateTTLSuite TimeoutMode.NoTimeouts(), TTLMode.EventTimeTTL()) - val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) - val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) - val ttlDuration = Duration.ofMinutes(1) - val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli - val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00") + val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00") + val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00") + val ttlExpirationMs = ttlExpiration.getTime + val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00") testStream(result)( - AddData(inputStream, InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + AddData(inputStream, InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)), CheckNewAnswer(), // get this state, and make sure we get unexpired value AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), From e2be0133efa967fe9489c1c51835ea04569d11d0 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 29 Mar 2024 10:56:47 -0700 Subject: [PATCH 17/21] init --- .../spark/sql/streaming/ListState.scala | 8 +- .../execution/streaming/ListStateImpl.scala | 8 +- .../streaming/ListStateImplWithTTL.scala | 322 ++++++++++++++++ .../StatefulProcessorHandleImpl.scala | 25 +- ....scala => TransformWithStateTTLTest.scala} | 360 ++++++++++++------ 5 files changed, 605 insertions(+), 118 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{TransformWithStateTTLSuite.scala => TransformWithStateTTLTest.scala} (72%) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 0e2d6cc3778c..c3d1ebca4a65 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.streaming +import java.time.Duration + import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @@ -33,13 +35,13 @@ private[sql] trait ListState[S] extends Serializable { def get(): Iterator[S] /** Update 
the value of the list. */ - def put(newState: Array[S]): Unit + def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Append an entry to the list */ - def appendValue(newState: S): Unit + def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Append an entire list to the existing value */ - def appendList(newState: Array[S]): Unit + def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Removes this state for the given grouping key. */ def clear(): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 56c9d2664d9e..3a7ed530fad4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -74,7 +76,7 @@ class ListStateImpl[S]( } /** Update the value of the list. */ - override def put(newState: Array[S]): Unit = { + override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() @@ -92,14 +94,14 @@ class ListStateImpl[S]( } /** Append an entry to the list. */ - override def appendValue(newState: S): Unit = { + override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) store.merge(stateTypesEncoder.encodeGroupingKey(), stateTypesEncoder.encodeValue(newState), stateName) } /** Append an entire list to the existing value. */ - override def appendList(newState: Array[S]): Unit = { + override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala new file mode 100644 index 000000000000..cf937b8fe99d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.streaming.{ListState, TTLMode} + +/** + * Provides concrete implementation for list of values associated with a state variable + * used in the streaming transformWithState operator. + * + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @tparam S - data type of object that will be stored in the list + */ +class ListStateImplWithTTL[S]( + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ListState[S] + with Logging + with StateVariableWithTTLSupport { + + private val keySerializer = keyExprEnc.createSerializer() + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: SingleKeyTTLStateImpl = _ + + initialize() + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + + if (ttlMode != TTLMode.NoTTL()) { + ttlState = new SingleKeyTTLStateImpl(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs) + } + } + /** Whether state exists or not. */ + override def exists(): Boolean = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val stateValue = store.get(encodedGroupingKey, stateName) + stateValue != null + } + + /** + * Get the state value if it exists. If the state does not exist in state store, an + * empty iterator is returned. 
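+ * Entries whose TTL has expired are skipped while iterating, so only unexpired
+ * values are returned to the caller.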
+ */ + override def get(): Iterator[S] = { + + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + var currentRow: UnsafeRow = null + + new Iterator[S] { + override def hasNext: Boolean = { + if (currentRow == null) { + setNextValidRow() + } + logError(s"### hasNext: currentRow is null = ${currentRow == null}") + currentRow != null + } + + override def next(): S = { + if (currentRow == null) { + setNextValidRow() + } + if (currentRow == null) { + throw new NoSuchElementException("Iterator is at the end") + } + logError(s"### in get for ListState") + val result = stateTypesEncoder.decodeValue(currentRow) + currentRow = null + logError(s"### result is null ${result == null}") + result + } + + // sets currentRow to a valid state, where we are + // pointing to a non-expired row + private def setNextValidRow(): Unit = { + assert(currentRow == null) + logError("### at the top of setNextValidRow") + if (unsafeRowValuesIterator.hasNext) { + currentRow = unsafeRowValuesIterator.next() + } else { + currentRow = null + return + } + while (unsafeRowValuesIterator.hasNext && (currentRow == null || isExpired(currentRow))) { + // log each of the conditions at the top of the while loop + logError(s"### unsafeRowValuesIterator.hasNext = ${unsafeRowValuesIterator.hasNext}," + + s" currentRow is null = ${currentRow == null}, isExpired = ${isExpired(currentRow)}") + currentRow = unsafeRowValuesIterator.next() + } + // in this case, we have iterated to the end, and there are no + // non-expired values + if (currentRow != null && isExpired(currentRow)) { + currentRow = null + } + logError(s"### setNextValidRow: currentRow is null = ${currentRow == null}") + } + } + } + + /** Update the value of the list. */ + override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { + validateNewState(newState) + logError(s"### in put for ListState") + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + var isFirst = true + + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + + newState.foreach { v => + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) + logError(s"### in put loop for ListState") + if (isFirst) { + store.put(encodedKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedKey, encodedValue, stateName) + } + } + ttlState.upsertTTLForStateKey(expirationMs, + serializedGroupingKey) + } + + /** Append an entry to the list. */ + override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { + StateStoreErrors.requireNonNullStateValue(newState, stateName) + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + store.merge(stateTypesEncoder.encodeGroupingKey(), + encodedValue, stateName) + ttlState.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey()) + } + + /** Append an entire list to the existing value. 
*/ + override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { + validateNewState(newState) + + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + + val encodedKey = stateTypesEncoder.encodeGroupingKey() + newState.foreach { v => + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) + store.merge(encodedKey, encodedValue, stateName) + } + ttlState.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey()) + } + + /** Remove this state. */ + override def clear(): Unit = { + store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + } + + private def validateNewState(newState: Array[S]): Unit = { + StateStoreErrors.requireNonNullStateValue(newState, stateName) + StateStoreErrors.requireNonEmptyListStateValue(newState, stateName) + + newState.foreach { v => + StateStoreErrors.requireNonNullStateValue(v, stateName) + } + } + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ + override def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName) + // We clear the list, and use the iterator to put back all of the non-expired values + store.remove(encodedGroupingKey, stateName) + var isFirst = true + unsafeRowValuesIterator.foreach { encodedValue => + if (!isExpired(encodedValue)) { + if (isFirst) { + store.put(encodedGroupingKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedGroupingKey, encodedValue, stateName) + } + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Iterator[S] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + new Iterator[S] { + override def hasNext: Boolean = { + unsafeRowValuesIterator.hasNext + } + override def next(): S = { + val valueUnsafeRow = unsafeRowValuesIterator.next() + stateTypesEncoder.decodeValue(valueUnsafeRow) + } + } + } + + /** + * Read the ttl value associated with the grouping key. 
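+ * Returns one entry per stored list element; None indicates that the element has
+ * no expiration set.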
+ */ + private[sql] def getTTLValues(): Iterator[Option[Long]] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + new Iterator[Option[Long]] { + override def hasNext: Boolean = { + unsafeRowValuesIterator.hasNext + } + + override def next(): Option[Long] = { + val valueUnsafeRow = unsafeRowValuesIterator.next() + stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow) + } + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + + val ttlIterator = ttlState.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 1c70e676bfd4..234d87f183c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -138,6 +138,25 @@ class StatefulProcessorHandleImpl( } } + override def getListState[T]( + stateName: String, + valEncoder: Encoder[T]): ListState[T] = { + verifyStateVarOperations("get_list_state") + + if (ttlMode == TTLMode.NoTTL()) { + new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + } else { + val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) + + val ttlState = listStateWithTTL.ttlState + ttlState.setStateVariable(listStateWithTTL) + ttlStates.add(ttlState) + + listStateWithTTL + } + } + override def getQueryInfo(): QueryInfo = currQueryInfo private lazy val timerState = new TimerStateImpl(store, timeoutMode, keyEncoder) @@ -222,12 +241,6 @@ class StatefulProcessorHandleImpl( store.removeColFamilyIfExists(stateName) } - override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verifyStateVarOperations("get_list_state") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - resultState - } - override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala similarity index 72% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 112c809d97f4..19de76e92f73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -22,7 +22,7 @@ import java.time.Duration import 
org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MemoryStream, ValueStateImplWithTTL} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -80,6 +80,40 @@ object TTLInputProcessFunction { results.iterator } + + def processRow( + row: InputEvent, + listState: ListStateImplWithTTL[Int]): Iterator[OutputEvent] = { + val key = row.key + var results = List[OutputEvent]() + if (row.action == "get") { + val currState = listState.get() + currState.foreach { v => + results = OutputEvent(key, v, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = listState.getWithoutEnforcingTTL() + currState.foreach { v => + results = OutputEvent(key, v, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpirations = listState.getTTLValues() + // for all values ttlExpiration for which isDefined is true, add to results + ttlExpirations.filter(_.isDefined).foreach { expiry => + results = OutputEvent(key, -1, isTTLValue = true, expiry.get) :: results + } + } else if (row.action == "put") { + listState.put(Array(row.value), row.ttl) + } else if (row.action == "append") { + listState.appendValue(row.value, row.ttl) + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = listState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + results.iterator + } } class ValueStateTTLProcessor @@ -118,8 +152,8 @@ class ValueStateTTLProcessor } case class MultipleValueStatesTTLProcessor( - ttlKey: String, - noTtlKey: String) + ttlKey: String, + noTtlKey: String) extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { @@ -128,9 +162,9 @@ case class MultipleValueStatesTTLProcessor( @transient private var _ttlMode: TTLMode = _ override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode, - ttlMode: TTLMode): Unit = { + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valueStateWithTTL = getHandle .getValueState("valueState", Encoders.scalaInt) .asInstanceOf[ValueStateImplWithTTL[Int]] @@ -141,10 +175,10 @@ case class MultipleValueStatesTTLProcessor( } override def handleInputRows( - key: String, - inputRows: Iterator[InputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { var results = List[OutputEvent]() val state = if (key == ttlKey) { _valueStateWithTTL @@ -162,10 +196,47 @@ case class MultipleValueStatesTTLProcessor( } } -class TransformWithStateTTLSuite +class ListStateTTLProcessor + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _listState: ListStateImplWithTTL[Int] = _ + @transient private var _ttlMode: TTLMode = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _listState = getHandle + .getListState("listState", Encoders.scalaInt) + .asInstanceOf[ListStateImplWithTTL[Int]] + _ttlMode = ttlMode + } + + override def handleInputRows( + 
key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + for (row <- inputRows) { + val resultIter = TTLInputProcessFunction.processRow(row, _listState) + resultIter.foreach { r => + results = r :: results + } + } + + results.iterator + } +} + +abstract class TransformWithStateTTLTest extends StreamTest { import testImplicits._ + def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] + test("validate state is evicted at ttl expiry - processing time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -174,7 +245,7 @@ class TransformWithStateTTLSuite val result = inputStream.toDS() .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + getProcessor(), TimeoutMode.NoTimeouts(), TTLMode.ProcessingTimeTTL()) @@ -221,7 +292,7 @@ class TransformWithStateTTLSuite val result = inputStream.toDS() .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + getProcessor(), TimeoutMode.NoTimeouts(), TTLMode.ProcessingTimeTTL()) @@ -283,7 +354,7 @@ class TransformWithStateTTLSuite val result = inputStream.toDS() .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + getProcessor(), TimeoutMode.NoTimeouts(), TTLMode.ProcessingTimeTTL()) @@ -334,62 +405,6 @@ class TransformWithStateTTLSuite } } - test("validate multiple value states - processing time ttl") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val ttlKey = "k1" - val noTtlKey = "k2" - - val inputStream = MemoryStream[InputEvent] - val result = inputStream.toDS() - .groupByKey(x => x.key) - .transformWithState( - MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), - TimeoutMode.NoTimeouts(), - TTLMode.ProcessingTimeTTL()) - - val clock = new StreamManualClock - testStream(result)( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), - AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), - // advance clock to trigger processing - AdvanceManualClock(1 * 1000), - CheckNewAnswer(), - // get both state values, and make sure we get unexpired value - AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), - AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), - AdvanceManualClock(1 * 1000), - CheckNewAnswer( - OutputEvent(ttlKey, 1, isTTLValue = false, -1), - OutputEvent(noTtlKey, 2, isTTLValue = false, -1) - ), - // ensure ttl values were added correctly, and noTtlKey has no ttl values - AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), - AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), - AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), - AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), - // advance clock after expiry - AdvanceManualClock(60 * 1000), - AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), - AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), - // advance clock to trigger processing - AdvanceManualClock(1 * 1000), - // 
validate ttlKey is expired, bot noTtlKey is still present - CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), - // validate ttl value is removed in the value state column family - AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), - AdvanceManualClock(1 * 1000), - CheckNewAnswer() - ) - } - } - test("validate state is evicted at ttl expiry - event time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -399,7 +414,7 @@ class TransformWithStateTTLSuite .withWatermark("eventTime", "1 second") .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + getProcessor(), TimeoutMode.NoTimeouts(), TTLMode.EventTimeTTL()) @@ -448,7 +463,7 @@ class TransformWithStateTTLSuite .withWatermark("eventTime", "1 second") .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + getProcessor(), TimeoutMode.NoTimeouts(), TTLMode.EventTimeTTL()) @@ -488,56 +503,189 @@ class TransformWithStateTTLSuite ) } } +} + +class ValueStateTTLSuite extends TransformWithStateTTLTest { + import testImplicits._ + + override def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] = { + new ValueStateTTLProcessor() + } - test("validate ttl removal keeps value in state - event time ttl") { + test("validate multiple value states - with and without ttl - processing time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val ttlKey = "k1" + val noTtlKey = "k2" + val inputStream = MemoryStream[InputEvent] val result = inputStream.toDS() - .withWatermark("eventTime", "1 second") .groupByKey(x => x.key) .transformWithState( - new ValueStateTTLProcessor(), + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), TimeoutMode.NoTimeouts(), - TTLMode.EventTimeTTL()) - - val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00") - val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00") - val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00") - val ttlExpirationMs = ttlExpiration.getTime - val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00") + TTLMode.ProcessingTimeTTL()) + val clock = new StreamManualClock testStream(result)( - AddData(inputStream, InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // get this state, and make sure we get unexpired value - AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), - CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), - // ensure ttl values were added correctly - AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), - CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), - AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), - CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), - // update state and remove ttl - AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime2)), - AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), - // validate value is not expired - CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // get both state values, and 
make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, but noTtlKey is still present + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), // validate ttl value is removed in the value state column family - AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + + test("validate multiple value states - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val ttlKey = "k1" + val noTtlKey = "k2" + + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // validate ttl state still has old ttl value present - AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), - CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), - // eventTime has been advanced to eventTim3 which is after older expiration value - // ensure unexpired value is still present in the state - AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), - CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), - // validate that the older expiration value is removed from ttl state - AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + // get both state values, and make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl 
values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, but noTtlKey is still present + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + } + + class ListStateTTLSuite extends TransformWithStateTTLTest { + + import testImplicits._ + + override def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] = { + new ListStateTTLProcessor() + } + + test("verify iterator works with expired values in middle of list - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + getProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + // Add three elements with a duration of a minute + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + AdvanceManualClock(1 * 1000), + AddData(inputStream, InputEvent("k1", "append", 2, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "append", 3, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // Add three elements with a duration of 15 seconds + AddData(inputStream, InputEvent("k1", "append", 4, Duration.ofSeconds(15))), + AddData(inputStream, InputEvent("k1", "append", 5, Duration.ofSeconds(15))), + AddData(inputStream, InputEvent("k1", "append", 6, Duration.ofSeconds(15))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // Add three elements with a duration of a minute + AddData(inputStream, InputEvent("k1", "append", 7, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "append", 8, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "append", 9, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // Advance clock to expire the middle three elements + AdvanceManualClock(30 * 1000), + // Get all elements in the list + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // Validate that the expired elements are not returned + CheckNewAnswer( + OutputEvent("k1", 1, isTTLValue = false, -1), + OutputEvent("k1", 2, isTTLValue = false, -1), + OutputEvent("k1", 3, isTTLValue = false, -1), + OutputEvent("k1", 7, 
isTTLValue = false, -1), + OutputEvent("k1", 8, isTTLValue = false, -1), + OutputEvent("k1", 9, isTTLValue = false, -1) + ) + ) + } + } +} From 2210994ab26194d510c04ef75cc9836fb7b866cd Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 29 Mar 2024 11:26:30 -0700 Subject: [PATCH 18/21] liststateimpl changes --- .../streaming/ListStateImplWithTTL.scala | 46 +++++++++++++++++++ .../streaming/TransformWithStateTTLTest.scala | 29 ++++++++---- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index cf937b8fe99d..e88d46be01f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -157,6 +157,28 @@ class ListStateImplWithTTL[S]( serializedGroupingKey) } + /** Update the value of the list. */ + def put(newState: Array[S], expirationMs: Long): Unit = { + validateNewState(newState) + logError(s"### in put for ListState") + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + var isFirst = true + + newState.foreach { v => + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) + logError(s"### in put loop for ListState") + if (isFirst) { + store.put(encodedKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedKey, encodedValue, stateName) + } + } + ttlState.upsertTTLForStateKey(expirationMs, + serializedGroupingKey) + } + /** Append an entry to the list. */ override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) @@ -191,6 +213,30 @@ class ListStateImplWithTTL[S]( stateTypesEncoder.serializeGroupingKey()) } + /** Append an entry to the list. */ + def appendValue(newState: S, expirationMs: Long): Unit = { + StateStoreErrors.requireNonNullStateValue(newState, stateName) + + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + store.merge(stateTypesEncoder.encodeGroupingKey(), + encodedValue, stateName) + ttlState.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey()) + } + + /** Append an entire list to the existing value. */ + def appendList(newState: Array[S], expirationMs: Long): Unit = { + validateNewState(newState) + + val encodedKey = stateTypesEncoder.encodeGroupingKey() + newState.foreach { v => + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) + store.merge(encodedKey, encodedValue, stateName) + } + ttlState.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey()) + } + /** Remove this state. 
*/ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 19de76e92f73..775a0246dbac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -82,10 +82,11 @@ object TTLInputProcessFunction { } def processRow( + ttlMode: TTLMode, row: InputEvent, listState: ListStateImplWithTTL[Int]): Iterator[OutputEvent] = { - val key = row.key var results = List[OutputEvent]() + val key = row.key if (row.action == "get") { val currState = listState.get() currState.foreach { v => @@ -97,21 +98,33 @@ object TTLInputProcessFunction { results = OutputEvent(key, v, isTTLValue = false, -1) :: results } } else if (row.action == "get_ttl_value_from_state") { - val ttlExpirations = listState.getTTLValues() - // for all values ttlExpiration for which isDefined is true, add to results - ttlExpirations.filter(_.isDefined).foreach { expiry => - results = OutputEvent(key, -1, isTTLValue = true, expiry.get) :: results + val ttlExpiration = listState.getTTLValues() + ttlExpiration.filter(_.isDefined).foreach { v => + results = OutputEvent(key, -1, isTTLValue = false, v.get) :: results } } else if (row.action == "put") { - listState.put(Array(row.value), row.ttl) + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + listState.put(Array(row.value), row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + listState.put(Array(row.value)) + } else { + listState.put(Array(row.value), row.ttl) + } } else if (row.action == "append") { - listState.appendValue(row.value, row.ttl) + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + listState.appendValue(row.value, row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + listState.appendValue(row.value) + } else { + listState.appendValue(row.value, row.ttl) + } } else if (row.action == "get_values_in_ttl_state") { val ttlValues = listState.getValuesInTTLState() ttlValues.foreach { v => results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results } } + results.iterator } } @@ -221,7 +234,7 @@ class ListStateTTLProcessor var results = List[OutputEvent]() for (row <- inputRows) { - val resultIter = TTLInputProcessFunction.processRow(row, _listState) + val resultIter = TTLInputProcessFunction.processRow(_ttlMode, row, _listState) resultIter.foreach { r => results = r :: results } From 7e4fa3db73f5d00b1b3d0b7087ea8b7574667f97 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 29 Mar 2024 11:52:52 -0700 Subject: [PATCH 19/21] setting expirationMs differently --- .../streaming/ListStateImplWithTTL.scala | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index e88d46be01f7..860f67b2fd51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -106,13 +106,14 @@ class ListStateImplWithTTL[S]( // pointing to a non-expired row private def setNextValidRow(): Unit 
= { assert(currentRow == null) - logError("### at the top of setNextValidRow") + logError(s"### at the top of setNextValidRow, hasNext = ${unsafeRowValuesIterator.hasNext}") if (unsafeRowValuesIterator.hasNext) { currentRow = unsafeRowValuesIterator.next() } else { currentRow = null return } + while (unsafeRowValuesIterator.hasNext && (currentRow == null || isExpired(currentRow))) { // log each of the conditions at the top of the while loop logError(s"### unsafeRowValuesIterator.hasNext = ${unsafeRowValuesIterator.hasNext}," + @@ -122,6 +123,7 @@ class ListStateImplWithTTL[S]( // in this case, we have iterated to the end, and there are no // non-expired values if (currentRow != null && isExpired(currentRow)) { + logError(s"### setting currentRow to null") currentRow = null } logError(s"### setNextValidRow: currentRow is null = ${currentRow == null}") @@ -137,11 +139,13 @@ class ListStateImplWithTTL[S]( val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() var isFirst = true - var expirationMs: Long = -1 - if (ttlDuration != null && ttlDuration != Duration.ZERO) { - expirationMs = StateTTL.calculateExpirationTimeForDuration( - ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) - } + val expirationMs = + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } else { + -1 + } newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) @@ -182,11 +186,13 @@ class ListStateImplWithTTL[S]( /** Append an entry to the list. */ override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) - var expirationMs: Long = -1 - if (ttlDuration != null && ttlDuration != Duration.ZERO) { - expirationMs = StateTTL.calculateExpirationTimeForDuration( - ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) - } + val expirationMs = + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } else { + -1 + } val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) store.merge(stateTypesEncoder.encodeGroupingKey(), encodedValue, stateName) @@ -198,11 +204,13 @@ class ListStateImplWithTTL[S]( override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) - var expirationMs: Long = -1 - if (ttlDuration != null && ttlDuration != Duration.ZERO) { - expirationMs = StateTTL.calculateExpirationTimeForDuration( - ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) - } + val expirationMs = + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } else { + -1 + } val encodedKey = stateTypesEncoder.encodeGroupingKey() newState.foreach { v => From 2f7f9aabaa80095f49c031bb0ea68b35dd240528 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 29 Mar 2024 11:57:15 -0700 Subject: [PATCH 20/21] adding log --- .../spark/sql/execution/streaming/ListStateImplWithTTL.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 860f67b2fd51..9c61089ffc48 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -123,7 +123,7 @@ class ListStateImplWithTTL[S]( // in this case, we have iterated to the end, and there are no // non-expired values if (currentRow != null && isExpired(currentRow)) { - logError(s"### setting currentRow to null") + logError(s"### setting currentRow to null as it is expired") currentRow = null } logError(s"### setNextValidRow: currentRow is null = ${currentRow == null}") From 3a71c66b662ec933449124550bf0f8e34a967b31 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 29 Mar 2024 14:25:57 -0700 Subject: [PATCH 21/21] logs --- .../sql/execution/streaming/ListStateImplWithTTL.scala | 3 ++- .../apache/spark/sql/execution/streaming/TTLState.scala | 3 ++- .../sql/execution/streaming/ValueStateImplWithTTL.scala | 1 + .../spark/sql/streaming/TransformWithStateTTLTest.scala | 8 +++++++- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 9c61089ffc48..094318ae0c31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -109,6 +109,7 @@ class ListStateImplWithTTL[S]( logError(s"### at the top of setNextValidRow, hasNext = ${unsafeRowValuesIterator.hasNext}") if (unsafeRowValuesIterator.hasNext) { currentRow = unsafeRowValuesIterator.next() + return } else { currentRow = null return @@ -146,7 +147,7 @@ class ListStateImplWithTTL[S]( } else { -1 } - + logError(s"### listState expirationMs: ${expirationMs}") newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) logError(s"### in put loop for ListState") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 127cb956628a..b60997bc801f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -155,7 +155,7 @@ class SingleKeyTTLStateImpl( /** * Helper methods for user State TTL. 
*/ -object StateTTL { +object StateTTL extends Logging { def calculateExpirationTimeForDuration( ttlMode: TTLMode, ttlDuration: Duration, @@ -177,6 +177,7 @@ object StateTTL { batchTimestampMs: Option[Long], eventTimeWatermarkMs: Option[Long]): Boolean = { if (ttlMode == TTLMode.ProcessingTimeTTL()) { + logError(s"### batchTimestampMs: ${batchTimestampMs.get}, expirationMs: ${expirationMs}") batchTimestampMs.get >= expirationMs } else if (ttlMode == TTLMode.EventTimeTTL()) { eventTimeWatermarkMs.get >= expirationMs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 77e8e72a70ca..92db728788e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -115,6 +115,7 @@ class ValueStateImplWithTTL[S]( -1 } + logError(s"### valueState expirationMs: ${expirationMs}") val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 775a0246dbac..7ea015b086ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -41,7 +41,7 @@ case class OutputEvent( isTTLValue: Boolean, ttlValue: Long) -object TTLInputProcessFunction { +object TTLInputProcessFunction extends Logging { def processRow( ttlMode: TTLMode, row: InputEvent, @@ -88,21 +88,25 @@ object TTLInputProcessFunction { var results = List[OutputEvent]() val key = row.key if (row.action == "get") { + logError(s"### get") val currState = listState.get() currState.foreach { v => results = OutputEvent(key, v, isTTLValue = false, -1) :: results } } else if (row.action == "get_without_enforcing_ttl") { + logError(s"### get without enforcing ttl") val currState = listState.getWithoutEnforcingTTL() currState.foreach { v => results = OutputEvent(key, v, isTTLValue = false, -1) :: results } } else if (row.action == "get_ttl_value_from_state") { + logError(s"### get ttl value from state") val ttlExpiration = listState.getTTLValues() ttlExpiration.filter(_.isDefined).foreach { v => results = OutputEvent(key, -1, isTTLValue = false, v.get) :: results } } else if (row.action == "put") { + logError(s"### put") if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { listState.put(Array(row.value), row.eventTimeTtl.getTime) } else if (ttlMode == TTLMode.EventTimeTTL()) { @@ -111,6 +115,7 @@ object TTLInputProcessFunction { listState.put(Array(row.value), row.ttl) } } else if (row.action == "append") { + logError(s"### append") if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { listState.appendValue(row.value, row.eventTimeTtl.getTime) } else if (ttlMode == TTLMode.EventTimeTTL()) { @@ -119,6 +124,7 @@ object TTLInputProcessFunction { listState.appendValue(row.value, row.ttl) } } else if (row.action == "get_values_in_ttl_state") { + logError(s"### get values in ttl state") val ttlValues = listState.getValuesInTTLState() ttlValues.foreach { v => results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results
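Note on the TTL arithmetic these suites assert: with the StreamManualClock and a one-second trigger, the first micro-batch runs at a batch timestamp of 1000 ms, so a one-minute processing-time TTL produces an expiration of 61000 ms, which is the value the CheckNewAnswer assertions expect; the entry is then treated as expired once a later batch timestamp reaches that expiration. The following is a minimal standalone Scala sketch of that processing-time calculation only; the object and method names are illustrative and not part of this patch, they simply mirror the processing-time branches of StateTTL.calculateExpirationTimeForDuration and StateTTL.isExpired shown above.

import java.time.Duration

object TtlArithmeticSketch {
  // Expiration for a processing-time TTL: the batch timestamp at write time
  // plus the requested duration.
  def expirationMs(batchTimestampMs: Long, ttl: Duration): Long =
    batchTimestampMs + ttl.toMillis

  // A value is considered expired once the current batch timestamp has
  // reached its stored expiration.
  def isExpired(batchTimestampMs: Long, expirationMs: Long): Boolean =
    batchTimestampMs >= expirationMs

  def main(args: Array[String]): Unit = {
    // Write happens in the batch triggered at 1000 ms with a one-minute TTL.
    val expiry = expirationMs(1000L, Duration.ofMinutes(1))
    println(expiry)                    // 61000, the value asserted in the suites
    println(isExpired(2000L, expiry))  // false: a read one second later still sees the value
    println(isExpired(65000L, expiry)) // true: after the 60-second clock jump the value is evicted
  }
}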