diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 11c8204d2c93..065385855431 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3541,6 +3541,12 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : { + "message" : [ + "State store operation= on state= does not support TTL in NoTTL() mode." + ], + "sqlState" : "42802" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation= with invalid handle state=." @@ -3559,6 +3565,13 @@ ], "sqlState" : "42802" }, + "STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE" : { + "message" : [ + "TTL duration is not allowed for event time ttl expiration on State store operation= on state=.", + "Use absolute expiration time instead." + ], + "sqlState" : "42802" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." @@ -4353,6 +4366,11 @@ "Removing column families with is not supported." ] }, + "STATE_STORE_TTL" : { + "message" : [ + "State TTL with is not supported. Please use RocksDBStateStoreProvider." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table does not support . Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e5dd67ccf787..5c0f4a7e8c81 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -830,12 +830,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. Defaults to APPEND mode. 
*/ def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { throw new UnsupportedOperationException } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 7b20dfb6bce5..94dfe20af56e 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -60,6 +60,8 @@ files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/> + ` is not supported. Removing column families with `` is not supported. +## STATE_STORE_TTL + +State TTL with `` is not supported. Please use RocksDBStateStoreProvider. + ## TABLE_OPERATION Table `` does not support ``. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 85b9e85ac420..25f068a3d3a2 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2150,6 +2150,12 @@ The SQL config `` cannot be found. Please verify that the config exists Star (*) is not allowed in a select list when GROUP BY an ordinal position is used. +### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +State store operation=`` on state=`` does not support TTL in NoTTL() mode. + ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -2168,6 +2174,13 @@ Failed to perform stateful processor operation=`` with invalid ti Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=``. +### STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +TTL duration is not allowed for event time ttl expiration on State store operation=`` on state=``. +Use absolute expiration time instead. + ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java new file mode 100644 index 000000000000..06af92dc1321 --- /dev/null +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
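The new error conditions above are easier to read next to the calls that would trigger them. The fragment below is a hedged illustration only: `spark` is an assumed SparkSession, `countState` is a hypothetical ValueState[Long] obtained from the processor handle, and the exact throw sites are an assumption based on the error messages rather than code shown in this diff.

    // State TTL is not supported by every state store provider; the STATE_STORE_TTL
    // message points at RocksDBStateStoreProvider as the supported one.
    spark.conf.set("spark.sql.streaming.stateStore.providerClass",
      "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

    // Query started with TTLMode.NoTTL(): supplying any TTL duration on a state update
    // fails with STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE.
    countState.update(1L, java.time.Duration.ofMinutes(10))

    // Query started with TTLMode.EventTimeTTL(): duration-based expiration fails with
    // STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE;
    // an absolute expiration timestamp in milliseconds is expected instead.
    countState.update(1L, expirationMs = 1735689600000L)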
+ */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.catalyst.plans.logical.*; + +/** + * Represents the type of ttl modes possible for the Dataset operations + * {@code transformWithState}. + */ +@Experimental +@Evolving +public class TTLMode { + + /** + * Specifies that there is no TTL for the user state. User state would not + * be cleaned up by Spark automatically. + */ + public static final TTLMode NoTTL() { + return NoTTL$.MODULE$; + } + + /** + * Specifies that all ttl durations for user state are in processing time. + */ + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } + + /** + * Specifies that all ttl durations for user state are in event time. + */ + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala new file mode 100644 index 000000000000..4b43060e0d35 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.streaming.TTLMode + +/** TTL types used in tranformWithState operator */ +case object NoTTL extends TTLMode + +case object ProcessingTimeTTL extends TTLMode + +case object EventTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 0e2d6cc3778c..c3d1ebca4a65 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.streaming +import java.time.Duration + import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @@ -33,13 +35,13 @@ private[sql] trait ListState[S] extends Serializable { def get(): Iterator[S] /** Update the value of the list. */ - def put(newState: Array[S]): Unit + def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Append an entry to the list */ - def appendValue(newState: S): Unit + def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Append an entire list to the existing value */ - def appendList(newState: Array[S]): Unit + def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Removes this state for the given grouping key. 
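A minimal, hedged sketch of passing one of these modes through the new transformWithState overload shown earlier in this change. The input Dataset, key extractor, and CountProcessor are hypothetical, encoders are assumed to come from spark.implicits, and note that the sql/core overload is still private[sql] at this stage, so code like this would live inside Spark itself (the Connect client signature above is the public-facing one).

    import org.apache.spark.sql.streaming.{OutputMode, TimeoutMode, TTLMode}

    val counts = inputDs
      .groupByKey(_.userId)                 // hypothetical key extractor
      .transformWithState(
        new CountProcessor(),               // hypothetical StatefulProcessor implementation
        TimeoutMode.NoTimeouts(),
        TTLMode.ProcessingTimeTTL(),        // expire state relative to batch processing time
        OutputMode.Update())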
*/ def clear(): Unit diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 1a61972f0ed0..fdec703412a8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -44,7 +44,8 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { */ def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 560188a0ff62..30f2d9000ecc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -30,10 +30,11 @@ import org.apache.spark.sql.Encoder private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type + * Function to create new or return existing single value state variable of given type. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. - * @param stateName - name of the state variable + * + * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 9c707c8308ab..95fc61980fff 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable +import java.time.Duration import org.apache.spark.annotation.{Evolving, Experimental} @@ -42,8 +43,22 @@ private[sql] trait ValueState[S] extends Serializable { /** Get the state if it exists as an option and None otherwise */ def getOption(): Option[S] - /** Update the value of the state. */ - def update(newState: S): Unit + /** + * Update the value of the state. + * @param newState the new value + * @param ttlDuration set the ttl to current batch processing time + * (for processing time TTL mode) plus ttlDuration + */ + def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit + + + /** + * Update the value of the state. + * + * @param newState the new value + * @param expirationMs set the ttl to expirationMs (processingTime or eventTime) + */ + def update(newState: S, expirationMs: Long): Unit /** Remove this state. 
*/ def clear(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index b2c443a8cce0..ff7c8fb3df4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.sql.types._ object CatalystSerde { @@ -574,6 +574,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { @@ -584,6 +585,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -605,6 +607,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan, @@ -618,6 +621,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -639,6 +643,7 @@ case class TransformWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 95ad973aee51..10586cd65963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -656,12 +656,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the * operator. * @param timeoutMode The timeout mode of the stateful processor. 
+ * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. * */ private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { Dataset[U]( sparkSession, @@ -669,6 +671,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan @@ -693,6 +696,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { @@ -702,6 +706,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cc212d99f299..e124c6f2edc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -751,7 +751,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => @@ -761,6 +761,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -925,15 +926,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, planLater(child), hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, - initialStateDeserializer, planLater(initialState)) :: Nil + initialStateDeserializer, planLater (initialState)) :: Nil case _: FlatMapGroupsInPandasWithState => // TODO(SPARK-40443): support applyInPandasWithState in batch query diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 662bef5716ea..3a7ed530fad4 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -16,10 +16,12 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState @@ -74,7 +76,7 @@ class ListStateImpl[S]( } /** Update the value of the list. */ - override def put(newState: Array[S]): Unit = { + override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() @@ -92,14 +94,14 @@ class ListStateImpl[S]( } /** Append an entry to the list. */ - override def appendValue(newState: S): Unit = { + override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) store.merge(stateTypesEncoder.encodeGroupingKey(), stateTypesEncoder.encodeValue(newState), stateName) } /** Append an entire list to the existing value. */ - override def appendList(newState: Array[S]): Unit = { + override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala new file mode 100644 index 000000000000..094318ae0c31 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
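Before the TTL-aware implementation that follows, a short hedged fragment of how the widened ListState API from earlier in this change reads from user code; `tags` is a hypothetical ListState[String] returned by getListState.

    import java.time.Duration

    // Replace the whole list; every element carries the same 30-minute TTL.
    tags.put(Array("a", "b"), Duration.ofMinutes(30))

    // Append one element or a batch; the Duration.ZERO default means no TTL is set.
    tags.appendValue("c", Duration.ofMinutes(30))
    tags.appendList(Array("d", "e"))

    // Iterate only the entries that have not yet expired.
    val live: Iterator[String] = tags.get()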
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.streaming.{ListState, TTLMode} + +/** + * Provides concrete implementation for list of values associated with a state variable + * used in the streaming transformWithState operator. + * + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @tparam S - data type of object that will be stored in the list + */ +class ListStateImplWithTTL[S]( + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ListState[S] + with Logging + with StateVariableWithTTLSupport { + + private val keySerializer = keyExprEnc.createSerializer() + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: SingleKeyTTLStateImpl = _ + + initialize() + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + + if (ttlMode != TTLMode.NoTTL()) { + ttlState = new SingleKeyTTLStateImpl(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs) + } + } + /** Whether state exists or not. */ + override def exists(): Boolean = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val stateValue = store.get(encodedGroupingKey, stateName) + stateValue != null + } + + /** + * Get the state value if it exists. If the state does not exist in state store, an + * empty iterator is returned. 
+   */
+  override def get(): Iterator[S] = {
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+    var currentRow: UnsafeRow = null
+
+    new Iterator[S] {
+      override def hasNext: Boolean = {
+        if (currentRow == null) {
+          setNextValidRow()
+        }
+        currentRow != null
+      }
+
+      override def next(): S = {
+        if (currentRow == null) {
+          setNextValidRow()
+        }
+        if (currentRow == null) {
+          throw new NoSuchElementException("Iterator is at the end")
+        }
+        val result = stateTypesEncoder.decodeValue(currentRow)
+        currentRow = null
+        result
+      }
+
+      // sets currentRow to a valid state, where we are
+      // pointing to a non-expired row
+      private def setNextValidRow(): Unit = {
+        assert(currentRow == null)
+        // advance until we find a non-expired row or exhaust the iterator
+        while (unsafeRowValuesIterator.hasNext && (currentRow == null || isExpired(currentRow))) {
+          currentRow = unsafeRowValuesIterator.next()
+        }
+        // in this case, we have iterated to the end, and there are no
+        // non-expired values
+        if (currentRow != null && isExpired(currentRow)) {
+          currentRow = null
+        }
+      }
+    }
+  }
+
+  /** Update the value of the list. */
+  override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = {
+    validateNewState(newState)
+
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    var isFirst = true
+
+    val expirationMs =
+      if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+        StateTTL.calculateExpirationTimeForDuration(
+          ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs)
+      } else {
+        -1
+      }
+
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs)
+      if (isFirst) {
+        store.put(encodedKey, encodedValue, stateName)
+        isFirst = false
+      } else {
+        store.merge(encodedKey, encodedValue, stateName)
+      }
+    }
+    ttlState.upsertTTLForStateKey(expirationMs,
+      serializedGroupingKey)
+  }
+
+  /** Update the value of the list. */
+  def put(newState: Array[S], expirationMs: Long): Unit = {
+    validateNewState(newState)
+
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    var isFirst = true
+
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs)
+      if (isFirst) {
+        store.put(encodedKey, encodedValue, stateName)
+        isFirst = false
+      } else {
+        store.merge(encodedKey, encodedValue, stateName)
+      }
+    }
+    ttlState.upsertTTLForStateKey(expirationMs,
+      serializedGroupingKey)
+  }
+
+  /** Append an entry to the list. */
+  override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = {
+    StateStoreErrors.requireNonNullStateValue(newState, stateName)
+    val expirationMs =
+      if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+        StateTTL.calculateExpirationTimeForDuration(
+          ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs)
+      } else {
+        -1
+      }
+    val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs)
+    store.merge(stateTypesEncoder.encodeGroupingKey(),
+      encodedValue, stateName)
+    ttlState.upsertTTLForStateKey(expirationMs,
+      stateTypesEncoder.serializeGroupingKey())
+  }
+
+  /** Append an entire list to the existing value. */
+  override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = {
+    validateNewState(newState)
+
+    val expirationMs =
+      if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+        StateTTL.calculateExpirationTimeForDuration(
+          ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs)
+      } else {
+        -1
+      }
+
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs)
+      store.merge(encodedKey, encodedValue, stateName)
+    }
+    ttlState.upsertTTLForStateKey(expirationMs,
+      stateTypesEncoder.serializeGroupingKey())
+  }
+
+  /** Append an entry to the list. */
+  def appendValue(newState: S, expirationMs: Long): Unit = {
+    StateStoreErrors.requireNonNullStateValue(newState, stateName)
+
+    val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs)
+    store.merge(stateTypesEncoder.encodeGroupingKey(),
+      encodedValue, stateName)
+    ttlState.upsertTTLForStateKey(expirationMs,
+      stateTypesEncoder.serializeGroupingKey())
+  }
+
+  /** Append an entire list to the existing value. */
+  def appendList(newState: Array[S], expirationMs: Long): Unit = {
+    validateNewState(newState)
+
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs)
+      store.merge(encodedKey, encodedValue, stateName)
+    }
+    ttlState.upsertTTLForStateKey(expirationMs,
+      stateTypesEncoder.serializeGroupingKey())
+  }
+
+  /** Remove this state. */
+  override def clear(): Unit = {
+    store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+  }
+
+  private def validateNewState(newState: Array[S]): Unit = {
+    StateStoreErrors.requireNonNullStateValue(newState, stateName)
+    StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
+
+    newState.foreach { v =>
+      StateStoreErrors.requireNonNullStateValue(v, stateName)
+    }
+  }
+
+  /**
+   * Clears the user state associated with this grouping key
+   * if it has expired. This function is called by Spark to perform
+   * cleanup at the end of transformWithState processing.
+ * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ + override def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName) + // We clear the list, and use the iterator to put back all of the non-expired values + store.remove(encodedGroupingKey, stateName) + var isFirst = true + unsafeRowValuesIterator.foreach { encodedValue => + if (!isExpired(encodedValue)) { + if (isFirst) { + store.put(encodedGroupingKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedGroupingKey, encodedValue, stateName) + } + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Iterator[S] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + new Iterator[S] { + override def hasNext: Boolean = { + unsafeRowValuesIterator.hasNext + } + override def next(): S = { + val valueUnsafeRow = unsafeRowValuesIterator.next() + stateTypesEncoder.decodeValue(valueUnsafeRow) + } + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValues(): Iterator[Option[Long]] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + new Iterator[Option[Long]] { + override def hasNext: Boolean = { + unsafeRowValuesIterator.hasNext + } + + override def next(): Option[Long] = { + val valueUnsafeRow = unsafeRowValuesIterator.next() + stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow) + } + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. 
+ */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + + val ttlIterator = ttlState.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index d2ccd0a77807..c58f32ed756d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -45,7 +45,7 @@ class MapStateImpl[K, V]( /** Whether state exists or not. */ override def exists(): Boolean = { - !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty + store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1d41db896cdf..c78d368f28ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructType} -object StateKeyValueRowSchema { +object TransformWithStateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) - val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType) + val VALUE_ROW_SCHEMA: StructType = new StructType() + .add("value", BinaryType) + val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType() + .add("value", BinaryType) + .add("ttlExpirationMs", LongType) } /** @@ -49,12 +54,17 @@ object StateKeyValueRowSchema { class StateTypesEncoder[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String) { - import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._ + stateName: String, + hasTtl: Boolean) extends Logging { + import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._ /** Variables reused for conversions between byte array and UnsafeRow */ private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) - private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA) + private val valueProjection = if (hasTtl) { + 
UnsafeProjection.create(VALUE_ROW_SCHEMA_WITH_TTL) + } else { + UnsafeProjection.create(VALUE_ROW_SCHEMA) + } /** Variables reused for value conversions between spark sql and object */ private val valExpressionEnc = encoderFor(valEncoder) @@ -65,22 +75,49 @@ class StateTypesEncoder[GK, V]( // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. def encodeGroupingKey(): UnsafeRow = { + val keyRow = keyProjection(InternalRow(serializeGroupingKey())) + keyRow + } + + /** + * Encodes the provided grouping key into Spark UnsafeRow. + * + * @param groupingKeyBytes serialized grouping key byte array + * @return encoded UnsafeRow + */ + def encodeSerializedGroupingKey( + groupingKeyBytes: Array[Byte]): UnsafeRow = { + val keyRow = keyProjection(InternalRow(groupingKeyBytes)) + keyRow + } + + def serializeGroupingKey(): Array[Byte] = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(stateName) } - val groupingKey = keyOption.get.asInstanceOf[GK] - val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() - val keyRow = keyProjection(InternalRow(keyByteArr)) - keyRow + keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } + /** + * Encode the specified value in Spark UnsafeRow with no ttl. + * The ttl expiration will be set to -1, specifying no TTL. + */ def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes)) - valRow + valueProjection(InternalRow(bytes)) + } + + /** + * Encode the specified value in Spark UnsafeRow + * with provided ttl expiration. + */ + def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + valueProjection(InternalRow(bytes, expirationMs)) } def decodeValue(row: UnsafeRow): V = { @@ -89,14 +126,29 @@ class StateTypesEncoder[GK, V]( val value = rowToObjDeserializer.apply(reusedValRow) value } + + /** + * Decode the ttl information out of Value row. If the ttl has + * not been set (-1L specifies no user defined value), the API will + * return None. 
+ */ + def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { + val expirationMs = row.getLong(1) + if (expirationMs == -1) { + None + } else { + Some(expirationMs) + } + } } object StateTypesEncoder { def apply[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String): StateTypesEncoder[GK, V] = { - new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) + stateName: String, + hasTtl: Boolean = false): StateTypesEncoder[GK, V] = { + new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) } } @@ -105,8 +157,9 @@ class CompositeKeyStateEncoder[GK, K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V], schemaForCompositeKeyRow: StructType, - stateName: String) - extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) { + stateName: String, + hasTtl: Boolean = false) + extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) { private val compositeKeyProjection = UnsafeProjection.create(schemaForCompositeKeyRow) private val reusedKeyRow = new UnsafeRow(userKeyEnc.schema.fields.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 5f3b794fd117..234d87f183c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.util import java.util.UUID import org.apache.spark.TaskContext @@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, TTLMode, ValueState} import org.apache.spark.util.Utils /** @@ -77,14 +78,23 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, - isStreaming: Boolean = true) + isStreaming: Boolean = true, + batchTimestampMs: Option[Long] = None, + eventTimeWatermarkMs: Option[Long] = None) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + /** + * Stores all the active ttl states, and is used to cleanup expired values + * in [[doTtlCleanup()]] function. 
+ */ + private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - private def buildQueryInfo(): QueryInfo = { + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { (BATCH_QUERY_ID, 0L) @@ -103,22 +113,48 @@ class StatefulProcessorHandleImpl( private var currState: StatefulProcessorHandleState = CREATED - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } - def setHandleState(newState: StatefulProcessorHandleState): Unit = { currState = newState } def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) - resultState + + if (ttlMode == TTLMode.NoTTL()) { + new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + } else { + val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) + + val ttlState = valueStateWithTTL.ttlState + ttlState.setStateVariable(valueStateWithTTL) + ttlStates.add(ttlState) + + valueStateWithTTL + } + } + + override def getListState[T]( + stateName: String, + valEncoder: Encoder[T]): ListState[T] = { + verifyStateVarOperations("get_list_state") + + if (ttlMode == TTLMode.NoTTL()) { + new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + } else { + val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlMode, batchTimestampMs, eventTimeWatermarkMs) + + val ttlState = listStateWithTTL.ttlState + ttlState.setStateVariable(listStateWithTTL) + ttlStates.add(ttlState) + + listStateWithTTL + } } override def getQueryInfo(): QueryInfo = currQueryInfo @@ -185,6 +221,16 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + /** + * Performs the user state cleanup based on assigned TTl values. Any state + * which is expired will be cleaned up from StateStore. + */ + def doTtlCleanup(): Unit = { + ttlStates.forEach { s => + s.clearExpiredState() + } + } + /** * Function to delete and purge state variable if defined previously * @@ -195,12 +241,6 @@ class StatefulProcessorHandleImpl( store.removeColFamilyIfExists(stateName) } - override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verifyStateVarOperations("get_list_state") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - resultState - } - override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala new file mode 100644 index 000000000000..b60997bc801f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
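To make doTtlCleanup concrete, a worked processing-time example derived from the code in this change; the state name and timestamps are illustrative only.

    import java.time.Duration

    // Batch N runs with batchTimestampMs = 1000000L and the processor calls
    //   countState.update(5L, Duration.ofMinutes(1))
    // The value row is written with a ttlExpirationMs column, and the secondary
    // index column family "_ttl_countState" gets the key (1060000L, groupingKey)
    // via upsertTTLForStateKey.
    val expirationMs = 1000000L + Duration.ofMinutes(1).toMillis   // 1060000

    // Batch N+1 runs with batchTimestampMs = 1200000L. Before store.commit(),
    // the operator calls processorHandle.doTtlCleanup(); the index entry is due because
    val expired = 1200000L >= expirationMs                          // true
    // so clearIfExpired(groupingKey) re-reads the primary state, confirms the stored
    // ttlExpirationMs has passed, and removes the expired value.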
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +/** + * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]]. + * + * @param groupingKey grouping key for which ttl is set + * @param expirationMs expiration time for the grouping key + */ +case class SingleKeyTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + +/** + * Represents a State variable which supports TTL. + */ +trait StateVariableWithTTLSupport { + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +/** + * Represents the underlying state for secondary TTL Index for a user defined + * state variable. + * + * This state allows Spark to query ttl values based on expiration time + * allowing efficient ttl cleanup. + */ +trait TTLState { + + /** + * Perform the user state clean up based on ttl values stored in + * this state. NOTE that its not safe to call this operation concurrently + * when the user can also modify the underlying State. Cleanup should be initiated + * after arbitrary state operations are completed by the user. + */ + def clearExpiredState(): Unit +} + +/** + * Manages the ttl information for user state keyed with a single key (grouping key). 
+ */ +class SingleKeyTTLStateImpl( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends TTLState + with Logging { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + private var state: StateVariableWithTTLSupport = _ + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + // TODO: use range scan once Range Scan PR is merged for StateStore + store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + override def clearExpiredState(): Unit = { + store.iterator(ttlColumnFamilyName).foreach { kv => + val expirationMs = kv.key.getLong(0) + val isExpired = StateTTL.isExpired(ttlMode, expirationMs, + batchTimestampMs, eventTimeWatermarkMs) + + if (isExpired) { + val groupingKey = kv.key.getBinary(1) + state.clearIfExpired(groupingKey) + + store.remove(kv.key, ttlColumnFamilyName) + } + } + } + + private[sql] def setStateVariable( + state: StateVariableWithTTLSupport): Unit = { + this.state = state + } + + private[sql] def iterator(): Iterator[SingleKeyTTLRow] = { + val ttlIterator = store.iterator(ttlColumnFamilyName) + + new Iterator[SingleKeyTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyTTLRow = { + val kv = ttlIterator.next() + SingleKeyTTLRow( + expirationMs = kv.key.getLong(0), + groupingKey = kv.key.getBinary(1) + ) + } + } + } +} + +/** + * Helper methods for user State TTL. 
+ */ +object StateTTL extends Logging { + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + throw new IllegalStateException(s"cannot calculate expiration time for" + + s" unknown ttl Mode $ttlMode") + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + logError(s"### batchTimestampMs: ${batchTimestampMs.get}, expirationMs: ${expirationMs}") + batchTimestampMs.get >= expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get >= expirationMs + } else { + throw new IllegalStateException(s"cannot evaluate expiry condition for" + + s" unknown ttl Mode $ttlMode") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index af321eecb4db..8d410b677c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -78,25 +78,25 @@ class TimerStateImpl( private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) - val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { TimerStateUtils.PROC_TIMERS_STATE_NAME } else { TimerStateUtils.EVENT_TIMERS_STATE_NAME } - val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), useMultipleValuesPerKey = false, isInternal = true) - val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1), useMultipleValuesPerKey = false, isInternal = true) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption - if (!keyOption.isDefined) { + if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(cfName) } keyOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 36b957f9d430..2390e19384f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper +import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.streaming._ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -42,6 +42,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data + * @param ttlMode defines the ttl Mode for user state * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -58,6 +59,7 @@ case class TransformWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -102,10 +104,6 @@ case class TransformWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes - protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - /** * Distribute by grouping attributes - We need the underlying data and the initial state data * to have the same grouping so that the data are co-located on the same task. @@ -283,6 +281,8 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { + // clean up any expired user state + processorHandle.doTtlCleanup() store.commit() } else { store.abort() @@ -299,19 +299,8 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - timeoutMode match { - case ProcessingTime => - if (batchTimestampMs.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case EventTime => - if (eventTimeWatermarkForEviction.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case _ => - } + validateTTLMode() + validateTimeoutMode() if (hasInitialState) { val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) @@ -331,9 +320,9 @@ case class TransformWithStateExec( val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) val store = StateStore.get( storeProviderId = storeProviderId, - keySchema = schemaForKeyRow, - valueSchema = schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), version = stateInfo.get.storeVersion, useColumnFamilies = true, storeConf = storeConf, @@ -351,9 +340,9 @@ case class TransformWithStateExec( if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator), useColumnFamilies = true @@ 
-401,9 +390,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider = StateStoreProvider.createAndInit( providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useColumnFamilies = true, storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value, @@ -426,10 +415,11 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode, + isStreaming, batchTimestampMs, eventTimeWatermarkForEviction) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } @@ -440,10 +430,10 @@ case class TransformWithStateExec( initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeoutMode, isStreaming) + keyEncoder, ttlMode, timeoutMode, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) // Check if is first batch @@ -461,9 +451,41 @@ case class TransformWithStateExec( processDataWithPartition(childDataIterator, store, processorHandle) } + + private def validateTimeoutMode(): Unit = { + timeoutMode match { + case ProcessingTime => + if (batchTimestampMs.isEmpty) { + throw StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + throw StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + } + + private def validateTTLMode(): Unit = { + ttlMode match { + case ProcessingTimeTTL => + if (batchTimestampMs.isEmpty) { + throw StateStoreErrors.missingTTLValues(ttlMode.toString) + } + + case EventTimeTTL => + if (eventTimeWatermarkForEviction.isEmpty) { + throw StateStoreErrors.missingTTLValues(ttlMode.toString) + } + + case _ => + } + } } -// scalastyle:off +// scalastyle:off argcount object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -473,6 +495,7 @@ object TransformWithStateExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -498,6 +521,7 @@ object TransformWithStateExec { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -515,4 +539,5 @@ object TransformWithStateExec { initialState) } } -// scalastyle:on +// scalastyle:on argcount + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 08876ca3032e..fa250d659137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -16,12 +16,13 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ValueState /** @@ -29,7 +30,7 @@ import org.apache.spark.sql.streaming.ValueState * variables used in the streaming transformWithState operator. * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition - * @param keyEnc - Spark SQL encoder for key + * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored */ @@ -37,18 +38,24 @@ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) extends ValueState[S] with Logging { + valEncoder: Encoder[S]) + extends ValueState[S] + with Logging { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLStateImpl] = None + + initialize() - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + } /** Function to check if state exists. 
Returns true if present and false otherwise */ override def exists(): Boolean = { - getImpl() != null + get() != null } /** Function to return Option of value if exists and None otherwise */ @@ -58,7 +65,9 @@ class ValueStateImpl[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val retRow = getImpl() + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + if (retRow != null) { stateTypesEncoder.decodeValue(retRow) } else { @@ -66,14 +75,34 @@ class ValueStateImpl[S]( } } - private def getImpl(): UnsafeRow = { - store.get(stateTypesEncoder.encodeGroupingKey(), stateName) + /** Function to update and overwrite state associated with given key */ + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + if (ttlDuration != Duration.ZERO) { + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) + } + + val encodedValue = stateTypesEncoder.encodeValue(newState) + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) } /** Function to update and overwrite state associated with given key */ - override def update(newState: S): Unit = { - store.put(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + override def update( + newState: S, + expirationMs: Long): Unit = { + + if (expirationMs != -1) { + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) + } + + val encodedValue = stateTypesEncoder.encodeValue(newState) + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) } /** Function to remove state for given key */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala new file mode 100644 index 000000000000..92db728788e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.streaming.{TTLMode, ValueState} + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables (with ttl expiration support) used in the streaming transformWithState operator. + * + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyExprEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @param ttlMode - TTL Mode for values stored in this state + * @param batchTimestampMs - processing timestamp of the current batch. + * @param eventTimeWatermarkMs - event time watermark for streaming query + * (same as watermark for state eviction) + * @tparam S - data type of object that will be stored + */ +class ValueStateImplWithTTL[S]( + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ValueState[S] + with Logging + with StateVariableWithTTLSupport { + + private val keySerializer = keyExprEnc.createSerializer() + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, + stateName, hasTtl = true) + private[sql] var ttlState: SingleKeyTTLStateImpl = _ + + initialize() + + private def initialize(): Unit = { + assert(ttlMode != TTLMode.NoTTL()) + + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + + ttlState = new SingleKeyTTLStateImpl(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs) + } + + /** Function to check if state exists. 
Returns true if present and false otherwise */ + override def exists(): Boolean = { + get() != null + } + + /** Function to return Option of value if exists and None otherwise */ + override def getOption(): Option[S] = { + Option(get()) + } + + /** Function to return associated value with key if exists and null otherwise */ + override def get(): S = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + + if (!isExpired(retRow)) { + resState + } else { + null.asInstanceOf[S] + } + } else { + null.asInstanceOf[S] + } + } + + /** Function to update and overwrite state associated with given key */ + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) { + throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode( + "update", stateName) + } + + val expirationMs = + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } else { + -1 + } + + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + + ttlState.upsertTTLForStateKey(expirationMs, serializedGroupingKey) + } + + override def update( + newState: S, + expirationMs: Long): Unit = { + + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + + ttlState.upsertTTLForStateKey(expirationMs, serializedGroupingKey) + } + + /** Function to remove state for given key */ + override def clear(): Unit = { + store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + if (isExpired(retRow)) { + store.remove(encodedGroupingKey, stateName) + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if it is expired. This method is used + * in tests to read the state store value, and ensure it is cleaned up at the + * end of the micro-batch.
+ */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + val ttlIterator = ttlState.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 2f72cbb0b0fc..6ee7382d7bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -39,6 +39,13 @@ object StateStoreErrors { ) } + def missingTTLValues(ttlMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for ttlMode=$ttlMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -117,6 +124,16 @@ object StateStoreErrors { StatefulProcessorCannotReInitializeState = { new StatefulProcessorCannotReInitializeState(groupingKey) } + + def cannotProvideTTLDurationForNoTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotAssignTTLInNoTTLMode = { + new StatefulProcessorCannotAssignTTLInNoTTLMode(operationType, stateName) + } + + def cannotProvideTTLDurationForEventTimeTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode = { + new StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) @@ -192,3 +209,17 @@ class StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin extends SparkUnsupportedOperationException( errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", messageParameters = Map("fieldName" -> fieldName, "index" -> index)) + +class StatefulProcessorCannotAssignTTLInNoTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", + messageParameters = 
Map("operationType" -> operationType, "stateName" -> stateName)) + +class StatefulProcessorCannotUseTTLDurationInEventTimeTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index e895e475b74d..51cfc1548b39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, ValueState} /** * Class that adds unit tests for ListState types used in arbitrary stateful @@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase { } checkError( - exception = e.asInstanceOf[SparkIllegalArgumentException], + exception = e, errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") @@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = 
provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index ce72061d39ea..7fa41b12795e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 662a5dbfaac4..aec828459fce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode} + /** * Class that adds tests to verify operations based on stateful processor handle @@ -48,7 +49,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -89,7 +90,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -107,7 +108,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } @@ -143,7 +144,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -164,7 +165,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -204,7 +205,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), 
keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8668b58672c7..8063c2cdb155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -48,7 +48,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -78,7 +79,7 @@ class ValueStateSuite extends StateVariableSuiteBase { testState.update(123) } checkError( - ex.asInstanceOf[SparkException], + ex1.asInstanceOf[SparkException], errorClass = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" @@ -92,7 +93,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -118,7 +120,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,7 +167,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], - TimeoutMode.NoTimeouts()) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val cfName = "_testState" val ex = intercept[SparkUnsupportedOperationException] { @@ -204,7 +207,8 @@ class ValueStateSuite extends StateVariableSuiteBase { 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +234,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 +261,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +288,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 95ab34d40131..5ccc14ab8a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -32,7 +32,8 @@ class TestListStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) } @@ -89,7 +90,8 @@ class ToggleSaveAndEmitProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean) } @@ -140,6 +142,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -160,6 +163,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), 
TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -180,6 +184,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -200,6 +205,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -220,6 +226,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -240,6 +247,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -260,6 +268,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -312,6 +321,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x) .transformWithState(new ToggleSaveAndEmitProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index db8cb8b810af..d32b9687d95f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -32,7 +32,8 @@ class TestMapStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _mapState = getHandle.getMapState("sessionState", Encoders.STRING, Encoders.STRING) } @@ -95,6 +96,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) @@ -121,6 +123,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -145,6 +148,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -168,6 +172,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) testStream(result, OutputMode.Append())( // Test exists() @@ -222,6 +227,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 9f2e2c2d9f02..031a515d3b79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -34,7 +34,10 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] @transient var _listState: ListState[Double] = _ @transient var _mapState: MapState[Double, Int] = _ - override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble) _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble) _mapState = getHandle.getMapState[Double, Int]( @@ -154,7 +157,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 1))) .toDS().groupByKey(x => x.key).mapValues(x => x) val query = kvDataSet.transformWithState(new InitialStateInMemoryTestClass(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) testStream(query, OutputMode.Update())( // non-exist key test @@ -232,7 +235,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val query = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf ) testStream(query, OutputMode.Update())( AddData(inputData, InitInputRow("init_1", "add", 50.0)), @@ -252,6 +255,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val result = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), createInitialDfForTest) @@ -270,6 +274,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val query = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24e68e3db9d8..3f21c50abae4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -40,7 +40,8 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) } @@ -103,8 +104,9 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - super.init(outputMode, timeoutMode) + timeoutMode: TimeoutMode, + ttlMode: TTLMode) : Unit = { + 
super.init(outputMode, timeoutMode, ttlMode) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } @@ -186,7 +188,8 @@ class MaxEventTimeStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState", Encoders.scalaLong) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) @@ -227,10 +230,12 @@ class RunningCountMostRecentStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } + override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -256,7 +261,8 @@ class MostRecentStatefulProcessorWithDeletion override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { getHandle.deleteIfExists("countState") _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } @@ -310,6 +316,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithError(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -331,6 +338,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -361,6 +369,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -404,6 +413,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -440,6 +450,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithMultipleTimers(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -475,6 +486,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new MaxEventTimeStatefulProcessor(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -516,6 +528,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() @@ -534,12 +547,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x._1) .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) val stream2 = inputData.toDS() .groupByKey(x => x._1) .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) 
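[Editor's note] The test updates in this file all follow the same pattern: StatefulProcessor.init now receives a TTLMode alongside OutputMode and TimeoutMode, and transformWithState takes the chosen TTLMode as an extra argument. The following sketch shows a processor written against the updated API; it only uses signatures that appear in this patch, and the state name and counting logic are illustrative, not part of the change:

// Sketch only, assuming the three-argument init introduced by this patch.
class SketchCountProcessor extends StatefulProcessor[String, String, (String, Long)] {
  @transient private var _countState: ValueState[Long] = _

  // init now receives the ttlMode configured on the query.
  override def init(
      outputMode: OutputMode,
      timeoutMode: TimeoutMode,
      ttlMode: TTLMode): Unit = {
    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
    // Accumulate a running count per grouping key and emit the new total.
    val count = _countState.getOption().getOrElse(0L) + inputRows.size
    _countState.update(count)
    Iterator.single((key, count))
  }
}

// Callers now pass the TTLMode explicitly, e.g.:
// inputData.toDS()
//   .groupByKey(x => x)
//   .transformWithState(new SketchCountProcessor(), TimeoutMode.NoTimeouts(),
//     TTLMode.NoTTL(), OutputMode.Update())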
testStream(stream1, OutputMode.Update())( @@ -572,6 +587,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -605,6 +621,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -638,6 +655,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -668,6 +686,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) } @@ -760,6 +779,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -778,7 +798,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { val result = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf + TTLMode.NoTTL(), TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf ) testStream(result, OutputMode.Update())( AddData(inputData, InitInputRow("a", "add", -1.0)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala new file mode 100644 index 000000000000..7ea015b086ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -0,0 +1,710 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.sql.Timestamp +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, MemoryStream, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock + +case class InputEvent( + key: String, + action: String, + value: Int, + ttl: Duration, + eventTime: Timestamp = null, + eventTimeTtl: Timestamp = null) + +case class OutputEvent( + key: String, + value: Int, + isTTLValue: Boolean, + ttlValue: Long) + +object TTLInputProcessFunction extends Logging { + def processRow( + ttlMode: TTLMode, + row: InputEvent, + valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = valueState.getWithoutEnforcingTTL() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpiration = valueState.getTTLValue() + if (ttlExpiration.isDefined) { + results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results + } + } else if (row.action == "put") { + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + valueState.update(row.value, row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + valueState.update(row.value) + } else { + valueState.update(row.value, row.ttl) + } + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = valueState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } + + def processRow( + ttlMode: TTLMode, + row: InputEvent, + listState: ListStateImplWithTTL[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + logError(s"### get") + val currState = listState.get() + currState.foreach { v => + results = OutputEvent(key, v, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + logError(s"### get without enforcing ttl") + val currState = listState.getWithoutEnforcingTTL() + currState.foreach { v => + results = OutputEvent(key, v, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + logError(s"### get ttl value from state") + val ttlExpiration = listState.getTTLValues() + ttlExpiration.filter(_.isDefined).foreach { v => + results = OutputEvent(key, -1, isTTLValue = false, v.get) :: results + } + } else if (row.action == "put") { + logError(s"### put") + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + listState.put(Array(row.value), row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + listState.put(Array(row.value)) + } else { + listState.put(Array(row.value), row.ttl) + } + } else if (row.action == "append") { + logError(s"### append") + if (ttlMode == TTLMode.EventTimeTTL() && row.eventTimeTtl != null) { + listState.appendValue(row.value, 
row.eventTimeTtl.getTime) + } else if (ttlMode == TTLMode.EventTimeTTL()) { + listState.appendValue(row.value) + } else { + listState.appendValue(row.value, row.ttl) + } + } else if (row.action == "get_values_in_ttl_state") { + logError(s"### get values in ttl state") + val ttlValues = listState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } +} + +class ValueStateTTLProcessor + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueState: ValueStateImplWithTTL[Int] = _ + @transient private var _ttlMode: TTLMode = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _valueState = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImplWithTTL[Int]] + _ttlMode = ttlMode + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + for (row <- inputRows) { + val resultIter = TTLInputProcessFunction.processRow(_ttlMode, row, _valueState) + resultIter.foreach { r => + results = r :: results + } + } + + results.iterator + } +} + +case class MultipleValueStatesTTLProcessor( + ttlKey: String, + noTtlKey: String) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImplWithTTL[Int] = _ + @transient private var _ttlMode: TTLMode = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _valueStateWithTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImplWithTTL[Int]] + _valueStateWithoutTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImplWithTTL[Int]] + _ttlMode = ttlMode + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val state = if (key == ttlKey) { + _valueStateWithTTL + } else { + _valueStateWithoutTTL + } + + for (row <- inputRows) { + val resultIterator = TTLInputProcessFunction.processRow(_ttlMode, row, state) + resultIterator.foreach { r => + results = r :: results + } + } + results.iterator + } +} + +class ListStateTTLProcessor + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _listState: ListStateImplWithTTL[Int] = _ + @transient private var _ttlMode: TTLMode = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _listState = getHandle + .getListState("listState", Encoders.scalaInt) + .asInstanceOf[ListStateImplWithTTL[Int]] + _ttlMode = ttlMode + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + for (row <- inputRows) { + val resultIter = TTLInputProcessFunction.processRow(_ttlMode, row, _listState) + resultIter.foreach { r => + results = r :: results + } + } + + results.iterator + } +} + +abstract class TransformWithStateTTLTest + extends 
StreamTest { + import testImplicits._ + + def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] + + test("validate state is evicted at ttl expiry - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + getProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + getProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update expiration time + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, 
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        // validate value is not expired
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate ttl value is updated in the state
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)),
+        // validate ttl state has both ttl values present
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000),
+          OutputEvent("k1", -1, isTTLValue = true, 95000)
+        ),
+        // advance clock after older expiration value
+        AdvanceManualClock(30 * 1000),
+        // ensure unexpired value is still present in the state
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate that the older expiration value is removed from ttl state
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000))
+      )
+    }
+  }
+
+  test("validate ttl removal keeps value in state - processing time ttl") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          getProcessor(),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get this state, and make sure we get unexpired value
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // ensure ttl values were added correctly
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        // advance clock and update state to remove ttl
+        AdvanceManualClock(30 * 1000),
+        AddData(inputStream, InputEvent("k1", "put", 1, null)),
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        // validate value is not expired
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate ttl value is removed in the value state column family
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // validate ttl state still has old ttl value present
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        // advance clock after older expiration value
+        AdvanceManualClock(30 * 1000),
+        // ensure unexpired value is still present in the state
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate that the older expiration value is removed from ttl state
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer()
+      )
+    }
+  }
+
+  test("validate state is evicted at ttl expiry - event time ttl") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .withWatermark("eventTime", "1 second")
+        .groupByKey(x => x.key)
+        .transformWithState(
+          getProcessor(),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.EventTimeTTL())
+
+      val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00")
+      val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00")
+      val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00")
+      val ttlExpirationMs = ttlExpiration.getTime
+      val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00")
+
+      testStream(result)(
+        AddData(inputStream,
+          InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)),
+        CheckNewAnswer(),
+        // get this state, and make sure we get unexpired value
+        AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)),
+        ProcessAllAvailable(),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // ensure ttl values were added correctly
+        AddData(inputStream,
+          InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)),
+        ProcessAllAvailable(),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)),
+        // increment event time so that key k1 expires
+        AddData(inputStream, InputEvent("k2", "put", 1, null, eventTime3)),
+        CheckNewAnswer(),
+        // validate that k1 has expired
+        AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)),
+        CheckNewAnswer(),
+        // ensure this state does not exist any longer in State
+        AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null, eventTime3)),
+        CheckNewAnswer(),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)),
+        CheckNewAnswer()
+      )
+    }
+  }
+
+  test("validate ttl update updates the expiration timestamp - event time ttl") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .withWatermark("eventTime", "1 second")
+        .groupByKey(x => x.key)
+        .transformWithState(
+          getProcessor(),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.EventTimeTTL())
+
+      val eventTime1 = Timestamp.valueOf("2024-01-01 00:00:00")
+      val eventTime2 = Timestamp.valueOf("2024-01-01 00:02:00")
+      val ttlExpiration = Timestamp.valueOf("2024-01-01 00:03:00")
+      val ttlExpirationMs = ttlExpiration.getTime
+      val eventTime3 = Timestamp.valueOf("2024-01-01 00:05:00")
+
+      testStream(result)(
+        AddData(inputStream,
+          InputEvent("k1", "put", 1, null, eventTime1, ttlExpiration)),
+        CheckNewAnswer(),
+        // get this state, and make sure we get unexpired value
+        AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)),
+        ProcessAllAvailable(),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // ensure ttl values were added correctly
+        AddData(inputStream,
+          InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)),
+        ProcessAllAvailable(),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)),
+        // remove ttl expiration for key k1, and move watermark past previous ttl value
+        AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime3)),
+        CheckNewAnswer(),
+        // validate that the key still exists
+        AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)),
+        CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)),
+        // ensure this ttl expiration time has been removed from state
+        AddData(inputStream,
+          InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)),
+        CheckNewAnswer(),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)),
+        ProcessAllAvailable()
+      )
+    }
+  }
+}
+
+class ValueStateTTLSuite extends TransformWithStateTTLTest {
+  import testImplicits._
+
+  override def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] = {
+    new ValueStateTTLProcessor()
+  }
+
+  test("validate multiple value states - with and without ttl - processing time ttl") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val ttlKey = "k1"
+      val noTtlKey = "k2"
+
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          MultipleValueStatesTTLProcessor(ttlKey, noTtlKey),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))),
+        AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get both state values, and make sure we get unexpired value
+        AddData(inputStream, InputEvent(ttlKey, "get", -1, null)),
+        AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent(ttlKey, 1, isTTLValue = false, -1),
+          OutputEvent(noTtlKey, 2, isTTLValue = false, -1)
+        ),
+        // ensure ttl values were added correctly, and noTtlKey has no ttl values
+        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)),
+        AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
+        AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)),
+        AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
+        // advance clock after expiry
+        AdvanceManualClock(60 * 1000),
+        AddData(inputStream, InputEvent(ttlKey, "get", -1, null)),
+        AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        // validate ttlKey is expired, but noTtlKey is still present
+        CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)),
+        // validate ttl value is removed in the value state column family
+        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer()
+      )
+    }
+  }
+}
+
+class ListStateTTLSuite extends TransformWithStateTTLTest {
+
+  import testImplicits._
+
+  override def getProcessor(): StatefulProcessor[String, InputEvent, OutputEvent] = {
+    new ListStateTTLProcessor()
+  }
+
+  test("verify iterator works with expired values in middle of list - processing time ttl") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          getProcessor(),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        // Add three elements with duration of a minute
+        AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))),
+        AdvanceManualClock(1 * 1000),
+        AddData(inputStream, InputEvent("k1", "append", 2, Duration.ofMinutes(1))),
+        AddData(inputStream, InputEvent("k1", "append", 3, Duration.ofMinutes(1))),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // Add three elements with a duration of 15 seconds
+        AddData(inputStream, InputEvent("k1", "append", 4, Duration.ofSeconds(15))),
+        AddData(inputStream, InputEvent("k1", "append", 5, Duration.ofSeconds(15))),
+        AddData(inputStream, InputEvent("k1", "append", 6, Duration.ofSeconds(15))),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // Add three elements with a duration of a minute
+        AddData(inputStream, InputEvent("k1", "append", 7, Duration.ofMinutes(1))),
+        AddData(inputStream, InputEvent("k1", "append", 8, Duration.ofMinutes(1))),
+        AddData(inputStream, InputEvent("k1", "append", 9, Duration.ofMinutes(1))),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // Advance clock to expire the middle three elements
+        AdvanceManualClock(30 * 1000),
+        // Get all elements in the list
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        // Validate that the expired elements are not returned
+        CheckNewAnswer(
+          OutputEvent("k1", 1, isTTLValue = false, -1),
+          OutputEvent("k1", 2, isTTLValue = false, -1),
+          OutputEvent("k1", 3, isTTLValue = false, -1),
+          OutputEvent("k1", 7, isTTLValue = false, -1),
+          OutputEvent("k1", 8, isTTLValue = false, -1),
+          OutputEvent("k1", 9, isTTLValue = false, -1)
+        )
+      )
+    }
+  }
+}
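For readers of the patch, the snippet below is a minimal sketch of how an application could wire a TTL-enabled processor through the new transformWithState overload exercised by these suites. It is not taken from this change: it reuses the InputEvent case class and ValueStateTTLProcessor added above, while the session setup, sink, query name, and TTL duration are illustrative assumptions.

// Illustrative sketch only: assumes a local SparkSession and that the InputEvent and
// ValueStateTTLProcessor definitions from the test sources in this patch are on the classpath.
import java.time.Duration

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode}

val spark = SparkSession.builder().master("local[2]").appName("ttl-sketch").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

// State TTL requires the RocksDB state store provider (see the STATE_STORE_TTL error class above).
spark.conf.set("spark.sql.streaming.stateStore.providerClass",
  "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

val inputStream = MemoryStream[InputEvent]
inputStream.addData(InputEvent("k1", "put", 1, Duration.ofMinutes(1)))

val query = inputStream.toDS()
  .groupByKey(_.key)
  .transformWithState(
    new ValueStateTTLProcessor(),   // stateful processor defined in this patch
    TimeoutMode.NoTimeouts(),       // no timer-based timeouts
    TTLMode.ProcessingTimeTTL())    // evict state once its processing-time TTL elapses
  .writeStream
  .format("memory")
  .queryName("ttl_sketch")
  .outputMode("append")
  .start()

query.processAllAvailable()
query.stop()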