From ff53e576953cb8d098a07a3767aa6ed3e821be36 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Sun, 17 Mar 2024 21:28:07 -0700 Subject: [PATCH 01/16] Add support for ValueState TTL. --- .../main/resources/error/error-classes.json | 5 + ...ditions-unsupported-feature-error-class.md | 4 + .../apache/spark/sql/streaming/TTLMode.java | 51 +++++++ .../sql/catalyst/plans/logical/TTLMode.scala | 26 ++++ .../streaming/StatefulProcessorHandle.scala | 11 +- .../spark/sql/streaming/ValueState.scala | 4 +- .../sql/catalyst/plans/logical/object.scala | 5 +- .../spark/sql/KeyValueGroupedDataset.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../execution/streaming/MapStateImpl.scala | 2 +- .../streaming/StateTypesEncoderUtils.scala | 42 +++++- .../streaming/StateVariableTTLSupport.scala | 116 +++++++++++++++ .../StatefulProcessorHandleImpl.scala | 31 +++- .../execution/streaming/TimerStateImpl.scala | 8 +- .../streaming/TransformWithStateExec.scala | 85 +++++++++-- .../execution/streaming/ValueStateImpl.scala | 94 ++++++++++--- .../execution/streaming/state/RocksDB.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 7 + .../streaming/state/StateStore.scala | 5 + .../streaming/state/StateStoreErrors.scala | 14 +- .../streaming/state/ListStateSuite.scala | 16 ++- .../streaming/state/MapStateSuite.scala | 11 +- .../state/StatefulProcessorHandleSuite.scala | 15 +- .../streaming/state/ValueStateSuite.scala | 27 ++-- .../TransformWithListStateSuite.scala | 8 ++ .../TransformWithMapStateSuite.scala | 5 + .../streaming/TransformWithStateSuite.scala | 133 ++++++++++++++++++ 27 files changed, 655 insertions(+), 85 deletions(-) create mode 100644 sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 717d5e6631ec1..74ac04ed10a6a 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -4336,6 +4336,11 @@ "Removing column families with is not supported." ] }, + "STATE_STORE_TTL" : { + "message" : [ + "State TTL with is not supported. Please use RocksDBStateStoreProvider." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table does not support . Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." diff --git a/docs/sql-error-conditions-unsupported-feature-error-class.md b/docs/sql-error-conditions-unsupported-feature-error-class.md index e580ecc63b188..f67d7caff63de 100644 --- a/docs/sql-error-conditions-unsupported-feature-error-class.md +++ b/docs/sql-error-conditions-unsupported-feature-error-class.md @@ -202,6 +202,10 @@ Creating multiple column families with `` is not supported. Removing column families with `` is not supported. +## STATE_STORE_TTL + +State TTL with `` is not supported. Please use RocksDBStateStoreProvider. + ## TABLE_OPERATION Table `` does not support ``. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". 
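Reviewer note: the following is a minimal usage sketch of the API surface this patch adds, not code from the change itself. The processor class name, key type, and the one-hour duration are illustrative. It registers a ValueState inside init(), writes values with a processing-time TTL through the new update(value, ttlDuration) overload, and wires the processor into transformWithState with TTLMode.ProcessingTimeTTL().

import java.time.Duration
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Illustrative processor: keeps a per-key count that becomes eligible for TTL
// cleanup one hour after the last update.
class CountWithTtlProcessor extends StatefulProcessor[String, String, (String, Long)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    // State variables must be registered while the handle is still in the CREATED state.
    countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    // Under TTLMode.ProcessingTimeTTL(), the duration is resolved against the batch timestamp.
    countState.update(count, Duration.ofHours(1))
    Iterator((key, count))
  }

  override def close(): Unit = {}
}

// Wiring, following the argument order of the new KeyValueGroupedDataset.transformWithState:
// ds.groupByKey(_.key)
//   .transformWithState(new CountWithTtlProcessor(), TimeoutMode.NoTimeouts(),
//     TTLMode.ProcessingTimeTTL(), OutputMode.Update())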
diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java new file mode 100644 index 0000000000000..210f4b78eb847 --- /dev/null +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.catalyst.plans.logical.*; + +/** + * Represents the type of ttl modes possible for user defined state + * in [[StatefulProcessor]]. + */ +@Experimental +@Evolving +public class TTLMode { + + /** + * Specifies that there is no TTL for the state object. Such objects would not + * be cleaned up by Spark automatically. + */ + public static final TTLMode NoTTL() { + return NoTTL$.MODULE$; + } + + /** + * Specifies that the specified ttl is in processing time. + */ + public static final TTLMode ProcessingTimeTTL() { + return ProcessingTimeTTL$.MODULE$; + } + + /** + * Specifies that the specified ttl is in event time. + */ + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala new file mode 100644 index 0000000000000..e0e02868fbf42 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.streaming.TTLMode + +/** Types of timeouts used in tranformWithState operator */ +case object NoTTL extends TTLMode + +case object ProcessingTimeTTL extends TTLMode + +case object EventTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 560188a0ff621..10a914e112477 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -30,15 +30,20 @@ import org.apache.spark.sql.Encoder private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type + * Function to create new or return existing single value state variable of given type. + * The state will be eventually cleaned up after the specified ttl. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. - * @param stateName - name of the state variable + * + * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable + * @param ttlMode - ttl mode for the state * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] + def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] /** * Creates new or returns existing list state associated with stateName. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 9c707c8308abf..36ee12fa83e11 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable +import java.time.Duration import org.apache.spark.annotation.{Evolving, Experimental} @@ -43,7 +44,8 @@ private[sql] trait ValueState[S] extends Serializable { def getOption(): Option[S] /** Update the value of the state. */ - def update(newState: S): Unit + // TODO(sahnib) confirm if this should be scala or Java type of Duration + def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Remove this state. 
*/ def clear(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index cb8673d20ed3d..75152ced43867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.sql.types._ object CatalystSerde { @@ -574,6 +574,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { @@ -584,6 +585,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -600,6 +602,7 @@ case class TransformWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 50ab2a41612b4..331ff425e1b2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. 
Users should not @@ -662,6 +662,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { Dataset[U]( sparkSession, @@ -669,6 +670,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f77d0fef4eb95..5bedcaf3e6e01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -751,7 +751,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputAttr, child) => val execPlan = TransformWithStateExec( keyDeserializer, @@ -759,6 +759,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -917,10 +918,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, child) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, planLater(child)) :: Nil case _: FlatMapGroupsInPandasWithState => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index d2ccd0a778074..c58f32ed756db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -45,7 +45,7 @@ class MapStateImpl[K, V]( /** Whether state exists or not. 
*/ override def exists(): Boolean = { - !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty + store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1d41db896cdf2..f5509abfc52f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -23,11 +23,19 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructType} object StateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) - val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType) + val VALUE_ROW_SCHEMA: StructType = new StructType() + .add("value", BinaryType) + .add("ttlExpirationMs", LongType) + + def encodeGroupingKeyBytes(keyBytes: Array[Byte]): UnsafeRow = { + val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) + val keyRow = keyProjection(InternalRow(keyBytes)) + keyRow + } } /** @@ -65,30 +73,50 @@ class StateTypesEncoder[GK, V]( // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. 
def encodeGroupingKey(): UnsafeRow = { + val keyRow = keyProjection(InternalRow(serializeGroupingKey())) + keyRow + } + + def encodeSerializedGroupingKey( + groupingKeyBytes: Array[Byte]): UnsafeRow = { + val keyRow = keyProjection(InternalRow(groupingKeyBytes)) + keyRow + } + + def serializeGroupingKey(): Array[Byte] = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(stateName) } - val groupingKey = keyOption.get.asInstanceOf[GK] - val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() - val keyRow = keyProjection(InternalRow(keyByteArr)) - keyRow + keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes)) + val valRow = valueProjection(InternalRow(bytes, 0L)) valRow } + def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + val valueRow = valueProjection(InternalRow(bytes, expirationMs)) + valueRow + } + def decodeValue(row: UnsafeRow): V = { val bytes = row.getBinary(0) reusedValRow.pointTo(bytes, bytes.length) val value = rowToObjDeserializer.apply(reusedValRow) value } + + def decodeTtlExpirationMs(row: UnsafeRow): Long = { + val expirationMs = row.getLong(1) + expirationMs + } } object StateTypesEncoder { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala new file mode 100644 index 0000000000000..836fff647a2bf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +trait StateVariableTTLSupport { + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +trait TTLState { + def clearExpiredState(): Unit +} + +class SingleKeyTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long], + state: StateVariableTTLSupport) + extends TTLState { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + override def clearExpiredState(): Unit = { + store.iterator(ttlColumnFamilyName).foreach { kv => + val expirationMs = kv.key.getLong(0) + val isExpired = StateTTL.isExpired(ttlMode, expirationMs, + batchTimestampMs, eventTimeWatermarkMs) + + if (isExpired) { + val groupingKey = kv.key.getBinary(1) + state.clearIfExpired(groupingKey) + + // TODO(sahnib) ; validate its safe to update inside iterator + store.remove(kv.key, ttlColumnFamilyName) + } + } + } +} + +object StateTTL { + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + -1L + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get > expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get > expirationMs + } else { + false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 9b905ad5235db..8aacd33fda995 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import 
java.util import java.util.UUID import org.apache.spark.TaskContext @@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, TTLMode, ValueState} import org.apache.spark.util.Utils /** @@ -77,14 +78,19 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, - isStreaming: Boolean = true) + isStreaming: Boolean = true, + batchTimestampMs: Option[Long] = None, + eventTimeWatermarkMs: Option[Long] = None) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - private def buildQueryInfo(): QueryInfo = { + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { (BATCH_QUERY_ID, 0L) @@ -115,9 +121,16 @@ class StatefulProcessorHandleImpl( def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verifyStateVarOperations("get_value_state") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] = { + verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + + "initialization is complete") + + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, + ttlMode, batchTimestampMs, eventTimeWatermarkMs) + resultState.ttlState.foreach(ttlStates.add(_)) + resultState } @@ -183,6 +196,12 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + def doTtlCleanup(): Unit = { + ttlStates.forEach { s => + s.clearExpiredState() + } + } + /** * Function to delete and purge state variable if defined previously * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index 6166374d25e94..170c7c205a8ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -78,25 +78,25 @@ class TimerStateImpl( private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) - val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { TimerStateUtils.PROC_TIMERS_STATE_NAME } else { TimerStateUtils.EVENT_TIMERS_STATE_NAME } - val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), useMultipleValuesPerKey = false, isInternal = true) - val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + private val 
tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex), useMultipleValuesPerKey = false, isInternal = true) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption - if (!keyOption.isDefined) { + if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(cfName) } keyOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 39365e92185ad..45678d9a21b7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -17,8 +17,13 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID +import java.util.concurrent.ForkJoinPool import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.util.control.NonFatal + +import org.apache.commons.lang3.SerializationUtils + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -26,10 +31,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -40,6 +45,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data + * @param ttlMode defines the ttl Mode * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -56,6 +62,7 @@ case class TransformWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -90,16 +97,59 @@ case class TransformWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes - protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - override def requiredChildDistribution: Seq[Distribution] = { StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, getStateInfo, conf) :: Nil } + private def startTTLCleanupThread(store: StateStore): ForkJoinPool = { + // get state name from the statefulProcessor + val ttlColFamilies = store.listColumnFamilies().filter(_.startsWith("ttl_")) + 
val threadPool = new ForkJoinPool(ttlColFamilies.size) + @volatile var exception: Option[Throwable] = None + // start thread in fork join pool for each ttl column family + ttlColFamilies.foreach { ttlColFamily => + threadPool.execute(() => { + try { + ttlFunc(store, ttlColFamily) + } catch { + case NonFatal(e) => + exception = Some(e) + logError(s"Error in TTL thread for stateName=$ttlColFamily", e) + } + }) + } + threadPool + } + + def ttlFunc(store: StateStore, ttlColFamily: String): Unit = { + val expiredKeyStateNames = + store.iterator(ttlColFamily).flatMap { kv => + val ttl = kv.key.getLong(0) + if (ttl <= System.currentTimeMillis()) { + Some(kv.key) + } else { + None + } + } + expiredKeyStateNames.foreach { keyStateName => + store.remove(keyStateName, ttlColFamily) + val stateName = SerializationUtils.deserialize( + keyStateName.getBinary(1)).asInstanceOf[String] + val groupingKey = keyStateName.getBinary(2) + val keyRow = StateKeyValueRowSchema.encodeGroupingKeyBytes(groupingKey) + val row = store.get(keyRow, stateName) + if (row != null) { + val ttl = row.getLong(1) + if (ttl != -1 && ttl <= System.currentTimeMillis()) { + store.remove(keyRow, stateName) + } + } + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( groupingAttributes.map(SortOrder(_, Ascending))) @@ -241,6 +291,8 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { + // join ttlBackgroundThread forkjoinpool + processorHandle.doTtlCleanup() store.commit() } else { store.abort() @@ -274,9 +326,9 @@ case class TransformWithStateExec( if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator), useColumnFamilies = true, @@ -306,9 +358,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider = StateStoreProvider.createAndInit( providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useColumnFamilies = true, storeConf = storeConf, hadoopConf = broadcastedHadoopConf.value, @@ -334,7 +386,8 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode, + isStreaming, batchTimestampMs, eventTimeWatermarkForEviction) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeoutMode) @@ -343,6 +396,7 @@ case class TransformWithStateExec( } } +// scalastyle:off argcount object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -352,6 +406,7 @@ object TransformWithStateExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: 
ExpressionEncoder[Any], @@ -372,6 +427,7 @@ object TransformWithStateExec { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -384,3 +440,6 @@ object TransformWithStateExec { isStreaming = false) } } + +// scalastyle:on argcount + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 08876ca3032ee..b09dd43c30b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -16,39 +16,59 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} -import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** * Class that provides a concrete implementation for a single value state associated with state * variables used in the streaming transformWithState operator. * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition - * @param keyEnc - Spark SQL encoder for key + * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param ttlMode ttl Mode to evict expired values from state + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermarkMs event time watermark for state eviction * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) extends ValueState[S] with Logging { + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ValueState[S] + with Logging + with StateVariableTTLSupport { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLState] = None + + initialize() - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + + if (ttlMode != TTLMode.NoTTL()) { + val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs, this) + ttlState = Some(_ttlState) + } + } /** Function to check if state exists. 
Returns true if present and false otherwise */ override def exists(): Boolean = { - getImpl() != null + get() != null } /** Function to return Option of value if exists and None otherwise */ @@ -58,26 +78,66 @@ class ValueStateImpl[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val retRow = getImpl() + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + if (retRow != null) { - stateTypesEncoder.decodeValue(retRow) + val resState = stateTypesEncoder.decodeValue(retRow) + + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) + val isExpired = StateTTL.isExpired(ttlMode, + expirationMs, batchTimestampMs, eventTimeWatermarkMs) + + if (!isExpired) { + resState + } else { + null.asInstanceOf[S] + } } else { null.asInstanceOf[S] } } - private def getImpl(): UnsafeRow = { - store.get(stateTypesEncoder.encodeGroupingKey(), stateName) - } - /** Function to update and overwrite state associated with given key */ - override def update(newState: S): Unit = { - store.put(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { + // TODO(sahnib) throw a StateStoreError here + throw new RuntimeException() + } + + var expirationMs: Long = -1 + if (ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + stateTypesEncoder.encodeValue(newState, expirationMs), stateName) + + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, serializedGroupingKey)) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) + val isExpired = StateTTL.isExpired(ttlMode, + expirationMs, batchTimestampMs, eventTimeWatermarkMs) + + if (!isExpired) { + store.remove(encodedGroupingKey, stateName) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1817104a5c223..1c1159f654546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,10 @@ class RocksDB( } } + def listColumnFamilies(): Seq[String] = { + Seq() + } + /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c8537f2a6a5b1..c6db9b5d000c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,6 +257,13 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } + + /** Return a list of column family names */ + override def listColumnFamilies(): Seq[String] = { + verify(useColumnFamilies, "Column families are not supported in this store") + // turn keys of keyValueEncoderMap to Seq + rocksDB.listColumnFamilies() + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dd97aa5b9afca..c79c244dcf443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,6 +122,11 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit + /** + * Return list of column family names. + */ + def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) + /** * Create column family with given name, if absent. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index a8d4c06bc83c4..d451f1d6db16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -49,6 +49,11 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } + def ttlNotSupportedWithProvider(stateStoreProvider: String): + StateStoreTTLNotSupportedException = { + new StateStoreTTLNotSupportedException(stateStoreProvider) + } + def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -117,7 +122,14 @@ object StateStoreErrors { class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) + +class StateStoreTTLNotSupportedException(stateStoreProvider: String) + extends SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index e895e475b74d9..51cfc1548b398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, ValueState} /** * Class that adds unit tests for ListState types used in arbitrary stateful @@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase { } checkError( - exception = e.asInstanceOf[SparkIllegalArgumentException], + exception = e, errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") @@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index ce72061d39ea2..7fa41b12795eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import 
org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 662a5dbfaac4f..aec828459fce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode} + /** * Class that adds tests to verify operations based on stateful processor handle @@ -48,7 +49,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) 
assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -89,7 +90,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -107,7 +108,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } @@ -143,7 +144,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -164,7 +165,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -204,7 +205,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8668b58672c7e..8063c2cdb155e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode, ValueState} import 
org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -48,7 +48,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -78,7 +79,7 @@ class ValueStateSuite extends StateVariableSuiteBase { testState.update(123) } checkError( - ex.asInstanceOf[SparkException], + ex1.asInstanceOf[SparkException], errorClass = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" @@ -92,7 +93,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -118,7 +120,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,7 +167,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], - TimeoutMode.NoTimeouts()) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val cfName = "_testState" val ex = intercept[SparkUnsupportedOperationException] { @@ -204,7 +207,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +234,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 
+261,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +288,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 95ab34d401311..51cc9ff87890c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -140,6 +140,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -160,6 +161,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -180,6 +182,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -200,6 +203,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -220,6 +224,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -240,6 +245,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -260,6 +266,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -312,6 +319,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x) .transformWithState(new ToggleSaveAndEmitProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, 
OutputMode.Update())( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index d7c5ce3815b04..46ed7e3fb3f5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -94,6 +94,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) @@ -120,6 +121,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -144,6 +146,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -167,6 +170,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) testStream(result, OutputMode.Append())( // Test exists() @@ -221,6 +225,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24b0d59c45c56..1514a43fe39b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,6 +34,35 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorZeroTTL + extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState("countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } + + override def close(): Unit = {} +} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -231,6 +260,7 @@ class RunningCountMostRecentStatefulProcessor _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } + override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -249,6 +279,41 @@ class RunningCountMostRecentStatefulProcessor } } +class 
RunningCountMostRecentStatefulProcessorWithTTL + extends StatefulProcessor[String, (String, String), (String, String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + @transient private var _mostRecent: ValueState[String] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState( + "countState", Encoders.scalaLong) + _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + val mostRecent = _mostRecent.getOption().getOrElse("") + + var output = List[(String, String, String)]() + inputRows.foreach { row => + _mostRecent.update(row._2) + _countState.update(count) + output = (key, count.toString, mostRecent) :: output + } + output.iterator + } + + override def close(): Unit = {} +} + class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -310,6 +375,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithError(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -331,6 +397,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -361,6 +428,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -404,6 +472,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -440,6 +509,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithMultipleTimers(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -475,6 +545,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new MaxEventTimeStatefulProcessor(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -516,6 +587,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() @@ -534,12 +606,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x._1) .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) val stream2 = inputData.toDS() .groupByKey(x => x._1) .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(stream1, 
OutputMode.Update())( @@ -559,6 +633,61 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - Zero duration TTL, should expire immediately") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorZeroTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + StopStream + ) + } + } + + test("transformWithState - multiple state variables with one TTL") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[(String, String)] + val stream1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(stream1, OutputMode.Update())( + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str2"), ("b", "str3")), + CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), + AddData(inputData, ("a", "str4"), ("b", "str5")), + CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), + StopStream + ) + } + } + test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -572,6 +701,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -605,6 +735,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -638,6 +769,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -760,6 +892,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( From 1bc16f6219d0542021b77a2f6d843680a43f3da5 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 13:09:12 -0700 Subject: [PATCH 02/16] Fix existing testcases, and add TTL testcase. 
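For reviewers, a minimal sketch of the user-facing pattern that the new TTL test exercises, assuming the usual suite imports (java.time.Duration, org.apache.spark.sql.Encoders, org.apache.spark.sql.streaming._). The processor name, state name and the one-minute duration are illustrative only and are not added by this patch; the sketch also assumes the query runs with TTLMode.ProcessingTimeTTL(), since passing a duration to update() under TTLMode.NoTTL() is rejected.

    class LoginCountProcessor
      extends StatefulProcessor[String, String, (String, Long)] {

      @transient private var countState: ValueState[Long] = _

      override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
        // the state variable is declared as usual; the ttl is supplied per update below
        countState = getHandle.getValueState("countState", Encoders.scalaLong)
      }

      override def handleInputRows(
          key: String,
          inputRows: Iterator[String],
          timerValues: TimerValues,
          expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
        val count = countState.getOption().getOrElse(0L) + inputRows.size
        // keep the counter for one minute of processing time; Spark cleans it up afterwards
        countState.update(count, Duration.ofMinutes(1))
        Iterator.single((key, count))
      }
    }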
--- .../streaming/StateVariableTTLSupport.scala | 13 +- .../StatefulProcessorHandleImpl.scala | 11 +- .../streaming/TransformWithStateExec.scala | 1 + .../execution/streaming/ValueStateImpl.scala | 2 +- .../streaming/TransformWithStateSuite.scala | 119 ------------------ .../TransformWithStateTTLSuite.scala | 89 +++++++++++++ 6 files changed, 109 insertions(+), 126 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala index 836fff647a2bf..07fa8c141b15d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.time.Duration +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} @@ -45,14 +46,15 @@ class SingleKeyTTLState( stateName: String, store: StateStore, batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long], - state: StateVariableTTLSupport) - extends TTLState { + eventTimeWatermarkMs: Option[Long]) + extends TTLState + with Logging { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ private val ttlColumnFamilyName = s"_ttl_$stateName" private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + private var state: StateVariableTTLSupport = _ // empty row used for values private val EMPTY_ROW = @@ -83,6 +85,11 @@ class SingleKeyTTLState( } } } + + private[sql] def setStateVariable( + state: StateVariableTTLSupport): Unit = { + this.state = state + } } object StateTTL { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 8aacd33fda995..39de2fde5e57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -89,6 +89,7 @@ class StatefulProcessorHandleImpl( private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" + logInfo(s"Created StatefulProcessorHandle") private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) @@ -124,12 +125,16 @@ class StatefulProcessorHandleImpl( override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("get_value_state") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) - resultState.ttlState.foreach(ttlStates.add(_)) + val ttlState = resultState.ttlState + + ttlState.foreach { s => + ttlStates.add(s) + s.setStateVariable(resultState) + } resultState } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 45678d9a21b7a..6285f0d34978b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -309,6 +309,7 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // TODO(sahnib) add validation for ttlMode timeoutMode match { case ProcessingTime => if (batchTimestampMs.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index b09dd43c30b85..7cfc873ef3a5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -61,7 +61,7 @@ class ValueStateImpl[S]( if (ttlMode != TTLMode.NoTTL()) { val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, - batchTimestampMs, eventTimeWatermarkMs, this) + batchTimestampMs, eventTimeWatermarkMs) ttlState = Some(_ttlState) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 1514a43fe39b4..0b79378ecc079 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,35 +34,6 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } -class RunningCountStatefulProcessorZeroTTL - extends StatefulProcessor[String, String, (String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - - override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState("countState", Encoders.scalaLong) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - if (count == 3) { - _countState.clear() - Iterator.empty - } else { - _countState.update(count) - Iterator((key, count.toString)) - } - } - - override def close(): Unit = {} -} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -279,41 +250,6 @@ class RunningCountMostRecentStatefulProcessor } } -class RunningCountMostRecentStatefulProcessorWithTTL - extends StatefulProcessor[String, (String, String), (String, String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - @transient private var _mostRecent: ValueState[String] = _ - - override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState( - "countState", Encoders.scalaLong) - _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) - } - - override def handleInputRows( - key: String, - inputRows: 
Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - val mostRecent = _mostRecent.getOption().getOrElse("") - - var output = List[(String, String, String)]() - inputRows.foreach { row => - _mostRecent.update(row._2) - _countState.update(count) - output = (key, count.toString, mostRecent) :: output - } - output.iterator - } - - override def close(): Unit = {} -} - class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -633,61 +569,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - Zero duration TTL, should expire immediately") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[String] - - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorZeroTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - StopStream - ) - } - } - - test("transformWithState - multiple state variables with one TTL") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[(String, String)] - val stream1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(stream1, OutputMode.Update())( - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str2"), ("b", "str3")), - CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), - AddData(inputData, ("a", "str4"), ("b", "str5")), - CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), - StopStream - ) - } - } - test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala new file mode 100644 index 0000000000000..a854781c5e2c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock + +class ValueStateTTLProcessor + extends StatefulProcessor[String, String, (String, Long)] + with Logging { + + @transient private var _countState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _countState = getHandle.getValueState("countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = { + + val currValueOption = _countState.getOption() + + var totalLogins: Long = inputRows.size + if (currValueOption.isDefined) { + totalLogins = totalLogins + currValueOption.get + } + + _countState.update(totalLogins, Duration.ofMinutes(1)) + + Iterator.single((key, totalLogins)) + } +} + +class TransformWithStateTTLSuite + extends StreamTest { + import testImplicits._ + + test("validate state is evicted at ttl expiry") { + + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[String] + val result = inputStream.toDS() + .groupByKey(x => x) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, "k1"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("k1", 1L)), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + AddData(inputStream, "k1"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("k1", 1)) + ) + } + } +} From 8c8c6b41366427b978f688ad86f25b830a858856 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 16:44:07 -0700 Subject: [PATCH 03/16] Improved documentation, added NERF error classes for user errors. 
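To make the mode handling easier to review, this is how a query is expected to opt in at the operator level. The input dataset and the processor are placeholders (the processor is the hypothetical sketch from the previous patch note), not code added here; only the transformWithState signature carrying the TTLMode argument comes from this series.

    val result = inputStream.toDS()
      .groupByKey(event => event)
      .transformWithState(
        new LoginCountProcessor(),        // hypothetical processor, for illustration only
        TimeoutMode.NoTimeouts(),
        TTLMode.ProcessingTimeTTL(),      // every ttlDuration passed to update() is
        OutputMode.Update())              // interpreted against the batch processing time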
--- .../main/resources/error/error-classes.json | 6 + .../apache/spark/sql/streaming/TTLMode.java | 16 +-- .../sql/catalyst/plans/logical/TTLMode.scala | 2 +- .../streaming/StatefulProcessorHandle.scala | 6 +- .../streaming/StateTypesEncoderUtils.scala | 32 +++++- ...cala => StateVariableWithTTLSupport.scala} | 49 ++++++++- .../StatefulProcessorHandleImpl.scala | 10 +- .../streaming/TransformWithStateExec.scala | 103 ++++++------------ .../execution/streaming/ValueStateImpl.scala | 33 +++--- .../execution/streaming/state/RocksDB.scala | 4 - .../state/RocksDBStateStoreProvider.scala | 7 -- .../streaming/state/StateStore.scala | 5 - .../streaming/state/StateStoreErrors.scala | 33 +++--- .../TransformWithStateTTLSuite.scala | 64 +++++++---- 14 files changed, 210 insertions(+), 160 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{StateVariableTTLSupport.scala => StateVariableWithTTLSupport.scala} (69%) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 74ac04ed10a6a..b81849716b1ae 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3530,6 +3530,12 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : { + "message" : [ + "State store operation= on state= does not support TTL in NoTTL() mode." + ], + "sqlState" : "42802" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation= with invalid handle state=." diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index 210f4b78eb847..016407db5e101 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -22,30 +22,30 @@ import org.apache.spark.sql.catalyst.plans.logical.*; /** - * Represents the type of ttl modes possible for user defined state - * in [[StatefulProcessor]]. + * Represents the type of ttl modes possible for the Dataset operations + * {@code transformWithState}. */ @Experimental @Evolving public class TTLMode { /** - * Specifies that there is no TTL for the state object. Such objects would not + * Specifies that there is no TTL for the user state. User state would not * be cleaned up by Spark automatically. */ - public static final TTLMode NoTTL() { + public static TTLMode NoTTL() { return NoTTL$.MODULE$; } /** - * Specifies that the specified ttl is in processing time. + * Specifies that all ttl durations for user state are in processing time. */ - public static final TTLMode ProcessingTimeTTL() { + public static TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** - * Specifies that the specified ttl is in event time. + * Specifies that all ttl durations for user state are in event time. 
*/ - public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } + public static TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala index e0e02868fbf42..4b43060e0d358 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.streaming.TTLMode -/** Types of timeouts used in tranformWithState operator */ +/** TTL types used in tranformWithState operator */ case object NoTTL extends TTLMode case object ProcessingTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 10a914e112477..30f2d9000ecc0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -31,19 +31,15 @@ private[sql] trait StatefulProcessorHandle extends Serializable { /** * Function to create new or return existing single value state variable of given type. - * The state will be eventually cleaned up after the specified ttl. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. * * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable - * @param ttlMode - ttl mode for the state * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[T]( - stateName: String, - valEncoder: Encoder[T]): ValueState[T] + def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] /** * Creates new or returns existing list state associated with stateName. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index f5509abfc52f9..f97b91ccc9606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -70,6 +70,8 @@ class StateTypesEncoder[GK, V]( private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer() private val reusedValRow = new UnsafeRow(valEncoder.schema.fields.length) + private val NO_TTL_ENCODED_VALUE: Long = -1L + // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. def encodeGroupingKey(): UnsafeRow = { @@ -77,6 +79,12 @@ class StateTypesEncoder[GK, V]( keyRow } + /** + * Encodes the provided grouping key into Spark UnsafeRow. 
+ * + * @param groupingKeyBytes serialized grouping key byte array + * @return encoded UnsafeRow + */ def encodeSerializedGroupingKey( groupingKeyBytes: Array[Byte]): UnsafeRow = { val keyRow = keyProjection(InternalRow(groupingKeyBytes)) @@ -92,13 +100,21 @@ class StateTypesEncoder[GK, V]( keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } + /** + * Encode the specified value in Spark UnsafeRow with no ttl. + * The ttl expiration will be set to -1, specifying no TTL. + */ def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes, 0L)) + val valRow = valueProjection(InternalRow(bytes, NO_TTL_ENCODED_VALUE)) valRow } + /** + * Encode the specified value in Spark UnsafeRow + * with provided ttl expiration. + */ def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() @@ -113,9 +129,19 @@ class StateTypesEncoder[GK, V]( value } - def decodeTtlExpirationMs(row: UnsafeRow): Long = { + /** + * Decode the ttl information out of Value row. If the ttl has + * not been set (-1L specifies no user defined value), the API will + * return None. + */ + def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { val expirationMs = row.getLong(1) - expirationMs + + if (expirationMs == NO_TTL_ENCODED_VALUE) { + None + } else { + Some(expirationMs) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala similarity index 69% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index 07fa8c141b15d..bfddb38f17bd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -33,14 +33,48 @@ object StateTTLSchema { StructType(Array(StructField("__dummy__", NullType))) } -trait StateVariableTTLSupport { +/** + * Represents a State variable which supports TTL. + */ +trait StateVariableWithTTLSupport { + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ def clearIfExpired(groupingKey: Array[Byte]): Unit } +/** + * Represents the underlying state for secondary TTL Index for a user defined + * state variable. + * + * This state allows Spark to query ttl values based on expiration time + * allowing efficient ttl cleanup. + */ trait TTLState { + + /** + * Perform the user state clean yp based on ttl values stored in + * this state. 
NOTE that its not safe to call this operation concurrently + * when the user can also modify the underlying State. Cleanup should be initiated + * after arbitrary state operations are completed by the user. + */ def clearExpiredState(): Unit } +/** + * Manages the ttl information for user state keyed with a single key (grouping key). + */ class SingleKeyTTLState( ttlMode: TTLMode, stateName: String, @@ -54,7 +88,7 @@ class SingleKeyTTLState( private val ttlColumnFamilyName = s"_ttl_$stateName" private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) - private var state: StateVariableTTLSupport = _ + private var state: StateVariableWithTTLSupport = _ // empty row used for values private val EMPTY_ROW = @@ -87,11 +121,14 @@ class SingleKeyTTLState( } private[sql] def setStateVariable( - state: StateVariableTTLSupport): Unit = { + state: StateVariableWithTTLSupport): Unit = { this.state = state } } +/** + * Helper methods for user State TTL. + */ object StateTTL { def calculateExpirationTimeForDuration( ttlMode: TTLMode, @@ -103,7 +140,8 @@ object StateTTL { } else if (ttlMode == TTLMode.EventTimeTTL()) { eventTimeWatermarkMs.get + ttlDuration.toMillis } else { - -1L + throw new IllegalStateException(s"cannot calculate expiration time for" + + s" unknown ttl Mode $ttlMode") } } @@ -117,7 +155,8 @@ object StateTTL { } else if (ttlMode == TTLMode.EventTimeTTL()) { eventTimeWatermarkMs.get > expirationMs } else { - false + throw new IllegalStateException(s"cannot evaluate expiry condition for" + + s" unknown ttl Mode $ttlMode") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 39de2fde5e57a..b50e5507a4d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -110,12 +110,6 @@ class StatefulProcessorHandleImpl( private var currState: StatefulProcessorHandleState = CREATED - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } - def setHandleState(newState: StatefulProcessorHandleState): Unit = { currState = newState } @@ -201,6 +195,10 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + /** + * Performs the user state cleanup based on assigned TTl values. Any state + * which is expired will be cleaned up from StateStore. 
+ */ def doTtlCleanup(): Unit = { ttlStates.forEach { s => s.clearExpiredState() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 6285f0d34978b..3a1b31422dbc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -17,13 +17,8 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID -import java.util.concurrent.ForkJoinPool import java.util.concurrent.TimeUnit.NANOSECONDS -import scala.util.control.NonFatal - -import org.apache.commons.lang3.SerializationUtils - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -45,7 +40,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data - * @param ttlMode defines the ttl Mode + * @param ttlMode defines the ttl Mode for user state * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -103,52 +98,6 @@ case class TransformWithStateExec( Nil } - private def startTTLCleanupThread(store: StateStore): ForkJoinPool = { - // get state name from the statefulProcessor - val ttlColFamilies = store.listColumnFamilies().filter(_.startsWith("ttl_")) - val threadPool = new ForkJoinPool(ttlColFamilies.size) - @volatile var exception: Option[Throwable] = None - // start thread in fork join pool for each ttl column family - ttlColFamilies.foreach { ttlColFamily => - threadPool.execute(() => { - try { - ttlFunc(store, ttlColFamily) - } catch { - case NonFatal(e) => - exception = Some(e) - logError(s"Error in TTL thread for stateName=$ttlColFamily", e) - } - }) - } - threadPool - } - - def ttlFunc(store: StateStore, ttlColFamily: String): Unit = { - val expiredKeyStateNames = - store.iterator(ttlColFamily).flatMap { kv => - val ttl = kv.key.getLong(0) - if (ttl <= System.currentTimeMillis()) { - Some(kv.key) - } else { - None - } - } - expiredKeyStateNames.foreach { keyStateName => - store.remove(keyStateName, ttlColFamily) - val stateName = SerializationUtils.deserialize( - keyStateName.getBinary(1)).asInstanceOf[String] - val groupingKey = keyStateName.getBinary(2) - val keyRow = StateKeyValueRowSchema.encodeGroupingKeyBytes(groupingKey) - val row = store.get(keyRow, stateName) - if (row != null) { - val ttl = row.getLong(1) - if (ttl != -1 && ttl <= System.currentTimeMillis()) { - store.remove(keyRow, stateName) - } - } - } - } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( groupingAttributes.map(SortOrder(_, Ascending))) @@ -291,7 +240,7 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { - // join ttlBackgroundThread forkjoinpool + // clean up any expired user state processorHandle.doTtlCleanup() store.commit() } else { @@ -309,20 +258,8 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - // TODO(sahnib) add validation 
for ttlMode - timeoutMode match { - case ProcessingTime => - if (batchTimestampMs.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case EventTime => - if (eventTimeWatermarkForEviction.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case _ => - } + validateTTLMode() + validateTimeoutMode() if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( @@ -395,6 +332,38 @@ case class TransformWithStateExec( processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } + + private def validateTimeoutMode(): Unit = { + timeoutMode match { + case ProcessingTime => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + } + + private def validateTTLMode(): Unit = { + ttlMode match { + case ProcessingTimeTTL => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case EventTimeTTL => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case _ => + } + } } // scalastyle:off argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 7cfc873ef3a5a..11f4c8e1f461d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -21,8 +21,9 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** @@ -32,9 +33,10 @@ import org.apache.spark.sql.streaming.{TTLMode, ValueState} * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value - * @param ttlMode ttl Mode to evict expired values from state + * @param ttlMode - TTL Mode for values stored in this state * @param batchTimestampMs processing timestamp of the current batch. 
- * @param eventTimeWatermarkMs event time watermark for state eviction + * @param eventTimeWatermarkMs event time watermark for streaming query + * (same as watermark for state eviction) * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( @@ -47,7 +49,7 @@ class ValueStateImpl[S]( eventTimeWatermarkMs: Option[Long]) extends ValueState[S] with Logging - with StateVariableTTLSupport { + with StateVariableWithTTLSupport { private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) @@ -84,11 +86,7 @@ class ValueStateImpl[S]( if (retRow != null) { val resState = stateTypesEncoder.decodeValue(retRow) - val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) - val isExpired = StateTTL.isExpired(ttlMode, - expirationMs, batchTimestampMs, eventTimeWatermarkMs) - - if (!isExpired) { + if (!isExpired(retRow)) { resState } else { null.asInstanceOf[S] @@ -104,8 +102,7 @@ class ValueStateImpl[S]( ttlDuration: Duration = Duration.ZERO): Unit = { if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { - // TODO(sahnib) throw a StateStoreError here - throw new RuntimeException() + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) } var expirationMs: Long = -1 @@ -131,13 +128,17 @@ class ValueStateImpl[S]( val retRow = store.get(encodedGroupingKey, stateName) if (retRow != null) { - val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(retRow) - val isExpired = StateTTL.isExpired(ttlMode, - expirationMs, batchTimestampMs, eventTimeWatermarkMs) - - if (!isExpired) { + if (isExpired(retRow)) { store.remove(encodedGroupingKey, stateName) } } } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1c1159f654546..1817104a5c223 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,10 +342,6 @@ class RocksDB( } } - def listColumnFamilies(): Seq[String] = { - Seq() - } - /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c6db9b5d000c7..c8537f2a6a5b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,13 +257,6 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } - - /** Return a list of column family names */ - override def listColumnFamilies(): Seq[String] = { - verify(useColumnFamilies, "Column families are not supported in this store") - // turn keys of keyValueEncoderMap to Seq - rocksDB.listColumnFamilies() - } } override def init( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index c79c244dcf443..dd97aa5b9afca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,11 +122,6 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit - /** - * Return list of column family names. - */ - def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) - /** * Create column family with given name, if absent. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index d451f1d6db16f..8e779a2d32914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -39,6 +39,13 @@ object StateStoreErrors { ) } + def missingTTLValues(ttlMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for ttlMode=$ttlMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -49,11 +56,6 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } - def ttlNotSupportedWithProvider(stateStoreProvider: String): - StateStoreTTLNotSupportedException = { - new StateStoreTTLNotSupportedException(stateStoreProvider) - } - def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -117,19 +119,17 @@ object StateStoreErrors { handleState: String): StatefulProcessorCannotPerformOperationWithInvalidHandleState = { new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState) } + + def cannotProvideTTLDurationForNoTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotAssignTTLInNoTTLMode = { + new StatefulProcessorCannotAssignTTLInNoTTLMode(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider) - ) - -class StateStoreTTLNotSupportedException(stateStoreProvider: String) - extends SparkUnsupportedOperationException( - errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider) - ) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( @@ -194,3 +194,10 @@ class StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin extends SparkUnsupportedOperationException( errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", messageParameters = 
Map("fieldName" -> fieldName, "index" -> index)) + +class StatefulProcessorCannotAssignTTLInNoTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index a854781c5e2c9..189398b0319b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -22,36 +22,53 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{RocksDB, RocksDBConf, RocksDBStateStoreProvider, StateStoreConf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock +case class InputEvent( + key: String, + action: String, + value: Int, + ttl: Duration) + +case class OutputEvent( + key: String, + exists: Boolean, + value: Int) + class ValueStateTTLProcessor - extends StatefulProcessor[String, String, (String, Long)] + extends StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _countState: ValueState[Long] = _ + @transient private var _valueState: ValueState[Int] = _ override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { - _countState = getHandle.getValueState("countState", Encoders.scalaLong) + _valueState = getHandle.getValueState("valueState", Encoders.scalaInt) } override def handleInputRows( key: String, - inputRows: Iterator[String], + inputRows: Iterator[InputEvent], timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = { + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() - val currValueOption = _countState.getOption() + for (row <- inputRows) { + if (row.action == "get") { + val currState = _valueState.getOption() - var totalLogins: Long = inputRows.size - if (currValueOption.isDefined) { - totalLogins = totalLogins + currValueOption.get + if (currState.isDefined) { + results = OutputEvent(key, exists = true, currState.get) :: results + } else { + results = OutputEvent(key, exists = false, -1) :: results + } + } else if (row.action == "put") { + _valueState.update(row.value, row.ttl) + } } - _countState.update(totalLogins, Duration.ofMinutes(1)) - - Iterator.single((key, totalLogins)) + results.iterator } } @@ -60,13 +77,12 @@ class TransformWithStateTTLSuite import testImplicits._ test("validate state is evicted at ttl expiry") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val inputStream = MemoryStream[String] + val inputStream = MemoryStream[InputEvent] val result = inputStream.toDS() - .groupByKey(x => x) + .groupByKey(x => x.key) .transformWithState( new ValueStateTTLProcessor(), TimeoutMode.NoTimeouts(), @@ -75,14 +91,22 @@ class TransformWithStateTTLSuite val clock = new StreamManualClock testStream(result)( StartStream(Trigger.ProcessingTime("1 
second"), triggerClock = clock), - AddData(inputStream, "k1"), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state + AddData(inputStream, InputEvent("k1", "get", 1, null)), + // advance clock to trigger processing AdvanceManualClock(1 * 1000), - CheckNewAnswer(("k1", 1L)), + CheckNewAnswer(OutputEvent("k1", exists = true, 1)), // advance clock so that state expires AdvanceManualClock(60 * 1000), - AddData(inputStream, "k1"), + AddData(inputStream, InputEvent("k1", "get", -1, null)), AdvanceManualClock(1 * 1000), - CheckNewAnswer(("k1", 1)) + // validate state does not exist anymore + CheckNewAnswer(OutputEvent("k1", exists = false, -1)) + // ensure this state does not exist any longer in State ) } } From 3a88c579d7c48e8c9e2392b876dd6b6f6420d9ca Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 16:48:40 -0700 Subject: [PATCH 04/16] Regenerate docs for SQL Error conditions. --- docs/sql-error-conditions.md | 6 ++++++ .../spark/sql/streaming/TransformWithStateTTLSuite.scala | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index b05a8d1ff61eb..e499192a50e4f 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2162,6 +2162,12 @@ The SQL config `` cannot be found. Please verify that the config exists Star (*) is not allowed in a select list when GROUP BY an ordinal position is used. +### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +State store operation=`` on state=`` does not support TTL in NoTTL() mode. + ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index 189398b0319b1..f74862b21735d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -22,7 +22,7 @@ import java.time.Duration import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{RocksDB, RocksDBConf, RocksDBStateStoreProvider, StateStoreConf} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock From 0abd3ac7a146e38c9eebb22ef18e4be1d1896f8c Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Mon, 25 Mar 2024 18:04:23 -0700 Subject: [PATCH 05/16] Add more testcases for value state ttl. 
--- .../apache/spark/sql/streaming/TTLMode.java | 6 +- .../spark/sql/streaming/ValueState.scala | 8 +- .../StateVariableWithTTLSupport.scala | 30 +- .../streaming/TransformWithStateExec.scala | 1 + .../execution/streaming/ValueStateImpl.scala | 75 ++- .../TransformWithStateTTLSuite.scala | 461 +++++++++++++++++- 6 files changed, 548 insertions(+), 33 deletions(-) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index 016407db5e101..b18a4b136828b 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -33,19 +33,19 @@ public class TTLMode { * Specifies that there is no TTL for the user state. User state would not * be cleaned up by Spark automatically. */ - public static TTLMode NoTTL() { + public static final TTLMode NoTTL() { return NoTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in processing time. */ - public static TTLMode ProcessingTimeTTL() { + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in event time. */ - public static TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 36ee12fa83e11..11e69dd08bf18 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -43,8 +43,12 @@ private[sql] trait ValueState[S] extends Serializable { /** Get the state if it exists as an option and None otherwise */ def getOption(): Option[S] - /** Update the value of the state. */ - // TODO(sahnib) confirm if this should be scala or Java type of Duration + /** + * Update the value of the state. + * @param newState the new value + * @param ttlDuration set the ttl to current batch processing time (for processing time TTL mode) + * or current watermark (for event time ttl mode) plus ttlDuration + */ def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Remove this state. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index bfddb38f17bd9..c46b03a9dcaee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -33,6 +33,15 @@ object StateTTLSchema { StructType(Array(StructField("__dummy__", NullType))) } +/** + * Encapsulates the ttl row information stored in [[SingleKeyTTLState]]. + * @param groupingKey grouping key for which ttl is set + * @param expirationMs expiration time for the grouping key + */ +case class SingleKeyTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + /** * Represents a State variable which supports TTL. 
*/ @@ -114,7 +123,6 @@ class SingleKeyTTLState( val groupingKey = kv.key.getBinary(1) state.clearIfExpired(groupingKey) - // TODO(sahnib) ; validate its safe to update inside iterator store.remove(kv.key, ttlColumnFamilyName) } } @@ -124,6 +132,22 @@ class SingleKeyTTLState( state: StateVariableWithTTLSupport): Unit = { this.state = state } + + private[sql] def iterator(): Iterator[SingleKeyTTLRow] = { + val ttlIterator = store.iterator(ttlColumnFamilyName) + + new Iterator[SingleKeyTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyTTLRow = { + val kv = ttlIterator.next() + SingleKeyTTLRow( + expirationMs = kv.key.getLong(0), + groupingKey = kv.key.getBinary(1) + ) + } + } + } } /** @@ -151,9 +175,9 @@ object StateTTL { batchTimestampMs: Option[Long], eventTimeWatermarkMs: Option[Long]): Boolean = { if (ttlMode == TTLMode.ProcessingTimeTTL()) { - batchTimestampMs.get > expirationMs + batchTimestampMs.get >= expirationMs } else if (ttlMode == TTLMode.EventTimeTTL()) { - eventTimeWatermarkMs.get > expirationMs + eventTimeWatermarkMs.get >= expirationMs } else { throw new IllegalStateException(s"cannot evaluate expiry condition for" + s" unknown ttl Mode $ttlMode") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 3a1b31422dbc9..4c8b2a02c8a93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -350,6 +350,7 @@ case class TransformWithStateExec( } private def validateTTLMode(): Unit = { + logWarning(s"Validating ttl Mode - $ttlMode $eventTimeWatermarkForEviction") ttlMode match { case ProcessingTimeTTL => if (batchTimestampMs.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 11f4c8e1f461d..8d9eaeb39f919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -106,7 +106,7 @@ class ValueStateImpl[S]( } var expirationMs: Long = -1 - if (ttlDuration != Duration.ZERO) { + if (ttlDuration != null && ttlDuration != Duration.ZERO) { expirationMs = StateTTL.calculateExpirationTimeForDuration( ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) } @@ -141,4 +141,77 @@ class ValueStateImpl[S]( isExpired.isDefined && isExpired.get } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. 
+ */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + if (ttlState.isEmpty) { + Iterator.empty + } + + val ttlIterator = ttlState.get.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala index f74862b21735d..a1acf3689d0ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.streaming -import java.time.Duration +import java.sql.Timestamp +import java.time.{Duration, Instant} +import java.time.temporal.ChronoUnit import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -30,21 +32,59 @@ case class InputEvent( key: String, action: String, value: Int, - ttl: Duration) + ttl: Duration, + eventTime: Timestamp = null) case class OutputEvent( key: String, - exists: Boolean, - value: Int) + value: Int, + isTTLValue: Boolean, + ttlValue: Long) + +object TTLInputProcessFunction { + def processRow( + row: InputEvent, + valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = valueState.getWithoutEnforcingTTL() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpiration = valueState.getTTLValue() + if (ttlExpiration.isDefined) { + results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results + } + } else if (row.action == "put") { + valueState.update(row.value, row.ttl) + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = valueState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } +} class ValueStateTTLProcessor extends 
StatefulProcessor[String, InputEvent, OutputEvent] with Logging { - @transient private var _valueState: ValueState[Int] = _ + @transient private var _valueState: ValueStateImpl[Int] = _ override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { - _valueState = getHandle.getValueState("valueState", Encoders.scalaInt) + _valueState = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] } override def handleInputRows( @@ -55,16 +95,9 @@ class ValueStateTTLProcessor var results = List[OutputEvent]() for (row <- inputRows) { - if (row.action == "get") { - val currState = _valueState.getOption() - - if (currState.isDefined) { - results = OutputEvent(key, exists = true, currState.get) :: results - } else { - results = OutputEvent(key, exists = false, -1) :: results - } - } else if (row.action == "put") { - _valueState.update(row.value, row.ttl) + val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + resultIter.foreach { r => + results = r :: results } } @@ -72,11 +105,51 @@ class ValueStateTTLProcessor } } +case class MultipleValueStatesTTLProcessor( + ttlKey: String, + noTtlKey: String) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueStateWithTTL: ValueStateImpl[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _valueStateWithTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + _valueStateWithoutTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val state = if (key == ttlKey) { + _valueStateWithTTL + } else { + _valueStateWithoutTTL + } + + for (row <- inputRows) { + val resultIterator = TTLInputProcessFunction.processRow(row, state) + resultIterator.foreach { r => + results = r :: results + } + } + results.iterator + } +} + class TransformWithStateTTLSuite extends StreamTest { import testImplicits._ - test("validate state is evicted at ttl expiry") { + test("validate state is evicted at ttl expiry - processing time ttl") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -95,18 +168,358 @@ class TransformWithStateTTLSuite // advance clock to trigger processing AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // get this state - AddData(inputStream, InputEvent("k1", "get", 1, null)), - // advance clock to trigger processing + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), AdvanceManualClock(1 * 1000), - CheckNewAnswer(OutputEvent("k1", exists = true, 1)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), // advance clock so that state expires 
AdvanceManualClock(60 * 1000), AddData(inputStream, InputEvent("k1", "get", -1, null)), AdvanceManualClock(1 * 1000), - // validate state does not exist anymore - CheckNewAnswer(OutputEvent("k1", exists = false, -1)) + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update expiration time + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is updated in the state + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)), + // validate ttl state has both ttl values present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000), + OutputEvent("k1", -1, isTTLValue = true, 95000) + ), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)) + ) + } + } + + test("validate ttl removal keeps value in state - processing time ttl") { + 
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update state to remove ttl + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, null)), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate multiple value states - with and without ttl - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val ttlKey = "k1" + val noTtlKey = "k2" + + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get both state values, and make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, 
"get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, bot noTtlKey is still present + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate state is evicted at ttl expiry - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // increment event time so that key k1 expires + AddData(inputStream, InputEvent("k2", "put", 1, ttlDuration, eventTime3)), + CheckNewAnswer(), + // validate that k1 has expired + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(), // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null, eventTime3)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + 
CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // remove tll expiration for key k1, and move watermark past previous ttl value + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime3)), + CheckNewAnswer(), + // validate that the key still exists + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // ensure this ttl expiration time has been removed from state + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable() + ) + } + } + + test("validate ttl removal keeps value in state - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + 
CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // update state and remove ttl + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime2)), + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // eventTime has been advanced to eventTim3 which is after older expiration value + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + CheckNewAnswer() ) } } From 71709fbc2eeccb90ef8d61ea3ccf49c0b1278095 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Mon, 25 Mar 2024 21:23:29 -0700 Subject: [PATCH 06/16] Fix indentation. --- .../src/main/java/org/apache/spark/sql/streaming/TTLMode.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index b18a4b136828b..06af92dc13210 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -40,9 +40,7 @@ public static final TTLMode NoTTL() { /** * Specifies that all ttl durations for user state are in processing time. */ - public static final TTLMode ProcessingTimeTTL() { - return ProcessingTimeTTL$.MODULE$; - } + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } /** * Specifies that all ttl durations for user state are in event time. From 2bb25df0a43165cc4dfeb1ac3cf5fe62a7ba54bd Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 08:49:31 -0700 Subject: [PATCH 07/16] Add a placeholder TODO to use range scan once available. 
--- .../sql/execution/streaming/StateVariableWithTTLSupport.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala index c46b03a9dcaee..286b926577fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -103,6 +103,7 @@ class SingleKeyTTLState( private val EMPTY_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + // TODO: use range scan once Range Scan PR is merged for StateStore store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) From 65f3737dcc4cd095711364b78df1c8ec7f1555d8 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 11:58:28 -0700 Subject: [PATCH 08/16] Remove unnecessary log in TransformWithStateExec. --- .../spark/sql/execution/streaming/TransformWithStateExec.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 4c8b2a02c8a93..3a1b31422dbc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -350,7 +350,6 @@ case class TransformWithStateExec( } private def validateTTLMode(): Unit = { - logWarning(s"Validating ttl Mode - $ttlMode $eventTimeWatermarkForEviction") ttlMode match { case ProcessingTimeTTL => if (batchTimestampMs.isEmpty) { From 382810159a50cb63c7dc044a34c6ec87f646256e Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 13:16:13 -0700 Subject: [PATCH 09/16] Rebase with latest master --- .../org/apache/spark/sql/streaming/TransformWithStateSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 0b79378ecc079..79422d6fa83ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -681,6 +681,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) } From e917f7d7a6e6be9480cc20f1edd3a1c7335d801b Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Tue, 26 Mar 2024 14:22:34 -0700 Subject: [PATCH 10/16] Suppress method checkstyle for TTLMode. --- dev/checkstyle-suppressions.xml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 7b20dfb6bce58..94dfe20af56e7 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -60,6 +60,8 @@ files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/> + Date: Tue, 26 Mar 2024 19:54:15 -0700 Subject: [PATCH 11/16] Modify Spark Connect tws API to include ttlMode. 
--- .../scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 5 ++++- .../scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e5dd67ccf7874..5c0f4a7e8c81a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -830,12 +830,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. Defaults to APPEND mode. */ def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 331ff425e1b2b..d62381f2b823e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -656,6 +656,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the * operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. * */ From af2b430149dfbc0d3b192a3804cece8cd04582d4 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Sun, 17 Mar 2024 21:28:07 -0700 Subject: [PATCH 12/16] Add support for ValueState TTL. 
--- .../streaming/StateVariableTTLSupport.scala | 116 +++++++++++++++++ .../execution/streaming/state/RocksDB.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 7 ++ .../streaming/state/StateStore.scala | 5 + .../streaming/state/StateStoreErrors.scala | 14 ++- .../streaming/TransformWithStateSuite.scala | 119 ++++++++++++++++++ 6 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala new file mode 100644 index 0000000000000..01c3b8b0257f4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +trait StateVariableTTLSupport { + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +trait TTLState { + def clearExpiredState(): Unit +} + +class SingleKeyTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long], + state: StateVariableTTLSupport) + extends TTLState { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, 0, + VALUE_ROW_SCHEMA, isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + override def clearExpiredState(): Unit = { + store.iterator(ttlColumnFamilyName).foreach { kv => + val expirationMs = kv.key.getLong(0) + val 
isExpired = StateTTL.isExpired(ttlMode, expirationMs, + batchTimestampMs, eventTimeWatermarkMs) + + if (isExpired) { + val groupingKey = kv.key.getBinary(1) + state.clearIfExpired(groupingKey) + + // TODO(sahnib) ; validate its safe to update inside iterator + store.remove(kv.key, ttlColumnFamilyName) + } + } + } +} + +object StateTTL { + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + -1L + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get > expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get > expirationMs + } else { + false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1817104a5c223..1c1159f654546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,10 @@ class RocksDB( } } + def listColumnFamilies(): Seq[String] = { + Seq() + } + /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c8537f2a6a5b1..c6db9b5d000c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,6 +257,13 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } + + /** Return a list of column family names */ + override def listColumnFamilies(): Seq[String] = { + verify(useColumnFamilies, "Column families are not supported in this store") + // turn keys of keyValueEncoderMap to Seq + rocksDB.listColumnFamilies() + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dd97aa5b9afca..c79c244dcf443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,6 +122,11 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit + /** + * Return list of column family names. + */ + def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) + /** * Create column family with given name, if absent. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 8e779a2d32914..cad31b088e96c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -56,6 +56,11 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } + def ttlNotSupportedWithProvider(stateStoreProvider: String): + StateStoreTTLNotSupportedException = { + new StateStoreTTLNotSupportedException(stateStoreProvider) + } + def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -129,7 +134,14 @@ object StateStoreErrors { class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) + +class StateStoreTTLNotSupportedException(stateStoreProvider: String) + extends SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 79422d6fa83ee..c57c0c2878576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,6 +34,35 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorZeroTTL + extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState("countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } + + override def close(): Unit = {} +} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -250,6 +279,41 @@ class RunningCountMostRecentStatefulProcessor } } +class RunningCountMostRecentStatefulProcessorWithTTL + extends StatefulProcessor[String, (String, String), (String, String, String)] + with Logging { + @transient private var _countState: ValueState[Long] = _ + @transient private var _mostRecent: ValueState[String] = _ + + 
override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode) : Unit = { + assert(getHandle.getQueryInfo().getBatchId >= 0) + _countState = getHandle.getValueState( + "countState", Encoders.scalaLong) + _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + val mostRecent = _mostRecent.getOption().getOrElse("") + + var output = List[(String, String, String)]() + inputRows.foreach { row => + _mostRecent.update(row._2) + _countState.update(count) + output = (key, count.toString, mostRecent) :: output + } + output.iterator + } + + override def close(): Unit = {} +} + class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -569,6 +633,61 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - Zero duration TTL, should expire immediately") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorZeroTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + StopStream + ) + } + } + + test("transformWithState - multiple state variables with one TTL") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[(String, String)] + val stream1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), + TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), + OutputMode.Update()) + + // State should expire immediately, meaning each answer is independent + // of previous counts + testStream(stream1, OutputMode.Update())( + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str2"), ("b", "str3")), + CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), + AddData(inputData, ("a", "str4"), ("b", "str5")), + CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), + StopStream + ) + } + } + test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, From 4a007a0be979e876290b175b6d4d67e3b0dbc96c Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 13:09:12 -0700 Subject: [PATCH 13/16] Fix existing testcases, and add TTL testcase. 
--- .../streaming/StateVariableTTLSupport.scala | 13 +- .../streaming/TransformWithStateSuite.scala | 119 ------------------ 2 files changed, 10 insertions(+), 122 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala index 01c3b8b0257f4..06ffb07bf8e46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.time.Duration +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.streaming.state.StateStore @@ -45,14 +46,15 @@ class SingleKeyTTLState( stateName: String, store: StateStore, batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long], - state: StateVariableTTLSupport) - extends TTLState { + eventTimeWatermarkMs: Option[Long]) + extends TTLState + with Logging { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ private val ttlColumnFamilyName = s"_ttl_$stateName" private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + private var state: StateVariableTTLSupport = _ // empty row used for values private val EMPTY_ROW = @@ -83,6 +85,11 @@ class SingleKeyTTLState( } } } + + private[sql] def setStateVariable( + state: StateVariableTTLSupport): Unit = { + this.state = state + } } object StateTTL { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index c57c0c2878576..79422d6fa83ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -34,35 +34,6 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } -class RunningCountStatefulProcessorZeroTTL - extends StatefulProcessor[String, String, (String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - - override def init( - outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState("countState", Encoders.scalaLong) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - if (count == 3) { - _countState.clear() - Iterator.empty - } else { - _countState.update(count) - Iterator((key, count.toString)) - } - } - - override def close(): Unit = {} -} class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -279,41 +250,6 @@ class RunningCountMostRecentStatefulProcessor } } -class RunningCountMostRecentStatefulProcessorWithTTL - extends StatefulProcessor[String, (String, String), (String, String, String)] - with Logging { - @transient private var _countState: ValueState[Long] = _ - @transient private var _mostRecent: ValueState[String] = _ - - override def init( - outputMode: 
OutputMode, - timeoutMode: TimeoutMode) : Unit = { - assert(getHandle.getQueryInfo().getBatchId >= 0) - _countState = getHandle.getValueState( - "countState", Encoders.scalaLong) - _mostRecent = getHandle.getValueState("mostRecent", Encoders.STRING) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { - val count = _countState.getOption().getOrElse(0L) + 1 - val mostRecent = _mostRecent.getOption().getOrElse("") - - var output = List[(String, String, String)]() - inputRows.foreach { row => - _mostRecent.update(row._2) - _countState.update(count) - output = (key, count.toString, mostRecent) :: output - } - output.iterator - } - - override def close(): Unit = {} -} - class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), (String, String)] with Logging { @@ -633,61 +569,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - Zero duration TTL, should expire immediately") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[String] - - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorZeroTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - StopStream - ) - } - } - - test("transformWithState - multiple state variables with one TTL") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData = MemoryStream[(String, String)] - val stream1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessorWithTTL(), - TimeoutMode.NoTimeouts(), - TTLMode.NoTTL(), - OutputMode.Update()) - - // State should expire immediately, meaning each answer is independent - // of previous counts - testStream(stream1, OutputMode.Update())( - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str2"), ("b", "str3")), - CheckNewAnswer(("a", "1", "str1"), ("b", "1", "")), - AddData(inputData, ("a", "str4"), ("b", "str5")), - CheckNewAnswer(("a", "1", "str2"), ("b", "1", "str3")), - StopStream - ) - } - } - test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, From 4e3445e1d9b77f43161c858e2920bdcd35510a68 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Fri, 22 Mar 2024 16:44:07 -0700 Subject: [PATCH 14/16] Improved documentation, added NERF error classes for user errors. 
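
The user-facing error added earlier in this series, STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE,
presumably fires when a TTL duration is supplied while the query runs with TTLMode.NoTTL().
The snippet below is illustrative only and the state name is made up:

    // query started with TTLMode.NoTTL()
    countState.update(1L, Duration.ofMinutes(1))
    // => STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE:
    //    "State store operation=<operationType> on state=countState does not
    //     support TTL in NoTTL() mode."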
--- .../streaming/StateVariableTTLSupport.scala | 123 ------------------ .../execution/streaming/state/RocksDB.scala | 4 - .../state/RocksDBStateStoreProvider.scala | 7 - .../streaming/state/StateStore.scala | 5 - 4 files changed, 139 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala deleted file mode 100644 index 06ffb07bf8e46..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableTTLSupport.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.streaming - -import java.time.Duration - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.streaming.TTLMode -import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} - -object StateTTLSchema { - val KEY_ROW_SCHEMA: StructType = new StructType() - .add("expirationMs", LongType) - .add("groupingKey", BinaryType) - val VALUE_ROW_SCHEMA: StructType = - StructType(Array(StructField("__dummy__", NullType))) -} - -trait StateVariableTTLSupport { - def clearIfExpired(groupingKey: Array[Byte]): Unit -} - -trait TTLState { - def clearExpiredState(): Unit -} - -class SingleKeyTTLState( - ttlMode: TTLMode, - stateName: String, - store: StateStore, - batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long]) - extends TTLState - with Logging { - - import org.apache.spark.sql.execution.streaming.StateTTLSchema._ - - private val ttlColumnFamilyName = s"_ttl_$stateName" - private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) - private var state: StateVariableTTLSupport = _ - - // empty row used for values - private val EMPTY_ROW = - UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) - - store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, 0, - VALUE_ROW_SCHEMA, isInternal = true) - - def upsertTTLForStateKey( - expirationMs: Long, - groupingKey: Array[Byte]): Unit = { - val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) - store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) - } - - override def clearExpiredState(): Unit = { - store.iterator(ttlColumnFamilyName).foreach { kv => - val expirationMs = kv.key.getLong(0) - val isExpired = StateTTL.isExpired(ttlMode, expirationMs, - 
batchTimestampMs, eventTimeWatermarkMs) - - if (isExpired) { - val groupingKey = kv.key.getBinary(1) - state.clearIfExpired(groupingKey) - - // TODO(sahnib) ; validate its safe to update inside iterator - store.remove(kv.key, ttlColumnFamilyName) - } - } - } - - private[sql] def setStateVariable( - state: StateVariableTTLSupport): Unit = { - this.state = state - } -} - -object StateTTL { - def calculateExpirationTimeForDuration( - ttlMode: TTLMode, - ttlDuration: Duration, - batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long]): Long = { - if (ttlMode == TTLMode.ProcessingTimeTTL()) { - batchTimestampMs.get + ttlDuration.toMillis - } else if (ttlMode == TTLMode.EventTimeTTL()) { - eventTimeWatermarkMs.get + ttlDuration.toMillis - } else { - -1L - } - } - - def isExpired( - ttlMode: TTLMode, - expirationMs: Long, - batchTimestampMs: Option[Long], - eventTimeWatermarkMs: Option[Long]): Boolean = { - if (ttlMode == TTLMode.ProcessingTimeTTL()) { - batchTimestampMs.get > expirationMs - } else if (ttlMode == TTLMode.EventTimeTTL()) { - eventTimeWatermarkMs.get > expirationMs - } else { - false - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1c1159f654546..1817104a5c223 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,10 +342,6 @@ class RocksDB( } } - def listColumnFamilies(): Seq[String] = { - Seq() - } - /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c6db9b5d000c7..c8537f2a6a5b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,13 +257,6 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } - - /** Return a list of column family names */ - override def listColumnFamilies(): Seq[String] = { - verify(useColumnFamilies, "Column families are not supported in this store") - // turn keys of keyValueEncoderMap to Seq - rocksDB.listColumnFamilies() - } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index c79c244dcf443..dd97aa5b9afca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,11 +122,6 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit - /** - * Return list of column family names. - */ - def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) - /** * Create column family with given name, if absent. 
*/ From bf31813afce4bc728996df3e4d52e2fc323da814 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Sun, 17 Mar 2024 21:28:07 -0700 Subject: [PATCH 15/16] State TTL: Initial Commit --- .../streaming/StatefulProcessorTTLState.scala | 162 ++++++++++++++++++ .../execution/streaming/state/RocksDB.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 7 + .../streaming/state/StateStore.scala | 5 + 4 files changed, 178 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala new file mode 100644 index 0000000000000..ce01a79e03502 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTTL, NoTTL, ProcessingTimeTTL} +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +class StatefulProcessorTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long], + createColumnFamily: Boolean = true) { + + private val ttlColumnFamilyName = s"_ttl_$stateName" + + private val schemaForKeyRow = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + .add("userKey", BinaryType) + private val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + + private val ttlKeyEncoder = UnsafeProjection.create(schemaForKeyRow) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + validate() + if (createColumnFamily) { + store.createColFamilyIfAbsent(ttlColumnFamilyName, schemaForKeyRow, 0, + schemaForValueRow, isInternal = true) + } + + private def validate(): Unit = { + ttlMode match { + case NoTTL => + throw new RuntimeException() + case ProcessingTimeTTL if batchTimestampMs.isEmpty => + throw new IllegalStateException() + case EventTimeTTL if eventTimeWatermarkMs.isEmpty => + throw new IllegalStateException() + case _ => + } + } + + private def expiredKeysIterator(): Iterator[InternalRow] = { + // TODO(sahnib): Need 
to merge Anish's changes + store.iterator(ttlColumnFamilyName).foreach { kv => + + } + + Iterator.empty + } + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte], + userKey: Option[Array[Byte]]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, + groupingKey, userKey.orNull)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } +} + +object StatefulProcessorTTLState { + def apply( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): StatefulProcessorTTLState = { + new StatefulProcessorTTLState(ttlMode, stateName, store, batchTimestampMs, eventTimeWatermarkMs) + } + + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + -1L + } + } + + def getCurrentExpirationTime( + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + } else { + -1L + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get > expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get > expirationMs + } else { + false + } + } + + def encodedTtlModeValue(ttLMode: TTLMode): Short = { + ttLMode match { + case NoTTL => + 0 + case ProcessingTimeTTL => + 1 + case EventTimeTTL => + 2 + } + } + + def decodedTtlMode(encodedVal: Short): TTLMode = { + encodedVal match { + case 0 => + TTLMode.NoTTL() + case 1 => + TTLMode.ProcessingTimeTTL() + case 2 => + TTLMode.EventTimeTTL() + case _ => + throw new IllegalStateException("encodedTtlValue should be <= 2") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1817104a5c223..1c1159f654546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,10 @@ class RocksDB( } } + def listColumnFamilies(): Seq[String] = { + Seq() + } + /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c8537f2a6a5b1..c6db9b5d000c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,6 +257,13 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } + + /** Return a list of column family names */ + override def listColumnFamilies(): Seq[String] = { + verify(useColumnFamilies, "Column families are not supported 
in this store") + // turn keys of keyValueEncoderMap to Seq + rocksDB.listColumnFamilies() + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dd97aa5b9afca..c79c244dcf443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,6 +122,11 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit + /** + * Return list of column family names. + */ + def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) + /** * Create column family with given name, if absent. */ From 7d03672f45bcc609827a162a1e7fbbdef66b05fc Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 27 Mar 2024 10:39:01 -0700 Subject: [PATCH 16/16] init --- .../spark/sql/streaming/ListState.scala | 8 +- .../execution/streaming/ListStateImpl.scala | 133 ++++++++++++++++-- .../streaming/StatefulProcessorTTLState.scala | 6 +- 3 files changed, 128 insertions(+), 19 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 0e2d6cc3778c6..c3d1ebca4a652 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.streaming +import java.time.Duration + import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @@ -33,13 +35,13 @@ private[sql] trait ListState[S] extends Serializable { def get(): Iterator[S] /** Update the value of the list. */ - def put(newState: Array[S]): Unit + def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Append an entry to the list */ - def appendValue(newState: S): Unit + def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Append an entire list to the existing value */ - def appendList(newState: Array[S]): Unit + def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Removes this state for the given grouping key. 
*/ def clear(): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 662bef5716ea2..14d9a877ea054 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} -import org.apache.spark.sql.streaming.ListState +import org.apache.spark.sql.streaming.{ListState, TTLMode} /** * Provides concrete implementation for list of values associated with a state variable @@ -37,12 +40,18 @@ class ListStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) - extends ListState[S] with Logging { + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ListState[S] + with Logging + with StateVariableWithTTLSupport { private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLState] = None store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) @@ -59,29 +68,65 @@ class ListStateImpl[S]( * empty iterator is returned. */ override def get(): Iterator[S] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + var currentRow: UnsafeRow = null + var shouldGetNextValidRow = true + var isFirst = true + new Iterator[S] { override def hasNext: Boolean = { - unsafeRowValuesIterator.hasNext + if (shouldGetNextValidRow) { + getNextValidRow() + } + currentRow != null } override def next(): S = { - val valueUnsafeRow = unsafeRowValuesIterator.next() - stateTypesEncoder.decodeValue(valueUnsafeRow) + if (shouldGetNextValidRow) { + getNextValidRow() + } + if (currentRow == null) { + throw new NoSuchElementException("Iterator is at the end") + } + shouldGetNextValidRow = true + stateTypesEncoder.decodeValue(currentRow) + } + + // sets currentRow to a valid state, where we are + // pointing to a non-expired row + private def getNextValidRow(): Unit = { + assert(shouldGetNextValidRow) + while (unsafeRowValuesIterator.hasNext && (isFirst || isExpired(currentRow))) { + isFirst = false + currentRow = unsafeRowValuesIterator.next() + } + // in this case, we have iterated to the end, and there are no + // non-expired values + if (isExpired(currentRow)) { + currentRow = null + } + shouldGetNextValidRow = false } } } /** Update the value of the list. 
*/ - override def put(newState: Array[S]): Unit = { + override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v) + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) if (isFirst) { store.put(encodedKey, encodedValue, stateName) isFirst = false @@ -89,24 +134,42 @@ class ListStateImpl[S]( store.merge(encodedKey, encodedValue, stateName) } } + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Append an entry to the list. */ - override def appendValue(newState: S): Unit = { + override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) store.merge(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + encodedValue, stateName) + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Append an entire list to the existing value. */ - override def appendList(newState: Array[S]): Unit = { + override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + val encodedKey = stateTypesEncoder.encodeGroupingKey() newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v) + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) store.merge(encodedKey, encodedValue, stateName) } + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Remove this state. */ @@ -122,4 +185,48 @@ class ListStateImpl[S]( StateStoreErrors.requireNonNullStateValue(v, stateName) } } - } + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. 
+ */ + override def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + clear() + var isFirst = true + + unsafeRowValuesIterator.foreach { encodedValue => + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + if (encodedValue != null) { + if (isExpired(encodedValue)) { + store.remove(encodedGroupingKey, stateName) + } else { + if (isFirst) { + store.put(encodedKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedKey, encodedValue, stateName) + } + } + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala index ce01a79e03502..4cdfe9e2f9ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala @@ -21,7 +21,7 @@ import java.time.Duration import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTTL, NoTTL, ProcessingTimeTTL} -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.TTLMode import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} @@ -50,8 +50,8 @@ class StatefulProcessorTTLState( validate() if (createColumnFamily) { - store.createColFamilyIfAbsent(ttlColumnFamilyName, schemaForKeyRow, 0, - schemaForValueRow, isInternal = true) + store.createColFamilyIfAbsent(ttlColumnFamilyName, schemaForKeyRow, + schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), isInternal = true) } private def validate(): Unit = {