diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 717d5e6631ec1..b81849716b1ae 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3530,6 +3530,12 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : { + "message" : [ + "State store operation= on state= does not support TTL in NoTTL() mode." + ], + "sqlState" : "42802" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation= with invalid handle state=." @@ -4336,6 +4342,11 @@ "Removing column families with is not supported." ] }, + "STATE_STORE_TTL" : { + "message" : [ + "State TTL with is not supported. Please use RocksDBStateStoreProvider." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table does not support . Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e5dd67ccf7874..5c0f4a7e8c81a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -830,12 +830,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. Defaults to APPEND mode. */ def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { throw new UnsupportedOperationException } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 7b20dfb6bce58..94dfe20af56e7 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -60,6 +60,8 @@ files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/> + ` is not supported. Removing column families with `` is not supported. +## STATE_STORE_TTL + +State TTL with `` is not supported. Please use RocksDBStateStoreProvider. + ## TABLE_OPERATION Table `` does not support ``. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". 
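For orientation, here is a minimal sketch of how the new `ttlMode` argument and the TTL-aware state updates introduced in this patch might be used together. Apart from `transformWithState`, `TTLMode`, `TimeoutMode`, `OutputMode`, `StatefulProcessor` and `ValueState.update(newState, ttlDuration)` (all added or touched later in this patch), every name here — `recordCount`, `countsWithTtl`, `myProcessor` — is hypothetical; note also that in this patch the method is only stubbed on the Connect client and remains `private[sql]` on the core `KeyValueGroupedDataset`.

```scala
import java.time.Duration

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode, TTLMode, ValueState}

// Inside a StatefulProcessor, a value can now be stored together with a TTL.
// `countState` would typically be obtained in init() via
// getHandle.getValueState[Long]("countState", Encoders.scalaLong).
def recordCount(countState: ValueState[Long], newCount: Long): Unit = {
  // With TTLMode.ProcessingTimeTTL(), this value expires one hour after the
  // current batch timestamp; Duration.ZERO (the default) means "no TTL".
  countState.update(newCount, Duration.ofHours(1))
}

// Wiring a processor into a grouped Dataset; the parameter order follows the
// new signature (statefulProcessor, timeoutMode, ttlMode, outputMode).
def countsWithTtl(
    spark: SparkSession,
    input: Dataset[String],
    myProcessor: StatefulProcessor[String, String, (String, Long)]): Dataset[(String, Long)] = {
  import spark.implicits._
  input
    .groupByKey(line => line.split(",")(0))
    .transformWithState(
      myProcessor,
      TimeoutMode.ProcessingTime(),
      TTLMode.ProcessingTimeTTL(),
      OutputMode.Update())
}
```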
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index b05a8d1ff61eb..e499192a50e4f 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2162,6 +2162,12 @@ The SQL config `` cannot be found. Please verify that the config exists Star (*) is not allowed in a select list when GROUP BY an ordinal position is used. +### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +State store operation=`` on state=`` does not support TTL in NoTTL() mode. + ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java new file mode 100644 index 0000000000000..06af92dc13210 --- /dev/null +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.catalyst.plans.logical.*; + +/** + * Represents the type of ttl modes possible for the Dataset operations + * {@code transformWithState}. + */ +@Experimental +@Evolving +public class TTLMode { + + /** + * Specifies that there is no TTL for the user state. User state would not + * be cleaned up by Spark automatically. + */ + public static final TTLMode NoTTL() { + return NoTTL$.MODULE$; + } + + /** + * Specifies that all ttl durations for user state are in processing time. + */ + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } + + /** + * Specifies that all ttl durations for user state are in event time. + */ + public static final TTLMode EventTimeTTL() { return EventTimeTTL$.MODULE$; } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala new file mode 100644 index 0000000000000..4b43060e0d358 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.streaming.TTLMode + +/** TTL types used in tranformWithState operator */ +case object NoTTL extends TTLMode + +case object ProcessingTimeTTL extends TTLMode + +case object EventTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 0e2d6cc3778c6..c3d1ebca4a652 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.streaming +import java.time.Duration + import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @@ -33,13 +35,13 @@ private[sql] trait ListState[S] extends Serializable { def get(): Iterator[S] /** Update the value of the list. */ - def put(newState: Array[S]): Unit + def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Append an entry to the list */ - def appendValue(newState: S): Unit + def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Append an entire list to the existing value */ - def appendList(newState: Array[S]): Unit + def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit /** Removes this state for the given grouping key. */ def clear(): Unit diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 560188a0ff621..30f2d9000ecc0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -30,10 +30,11 @@ import org.apache.spark.sql.Encoder private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type + * Function to create new or return existing single value state variable of given type. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. 
- * @param stateName - name of the state variable + * + * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 9c707c8308abf..11e69dd08bf18 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable +import java.time.Duration import org.apache.spark.annotation.{Evolving, Experimental} @@ -42,8 +43,13 @@ private[sql] trait ValueState[S] extends Serializable { /** Get the state if it exists as an option and None otherwise */ def getOption(): Option[S] - /** Update the value of the state. */ - def update(newState: S): Unit + /** + * Update the value of the state. + * @param newState the new value + * @param ttlDuration set the ttl to current batch processing time (for processing time TTL mode) + * or current watermark (for event time ttl mode) plus ttlDuration + */ + def update(newState: S, ttlDuration: Duration = Duration.ZERO): Unit /** Remove this state. */ def clear(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index cb8673d20ed3d..75152ced43867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.sql.types._ object CatalystSerde { @@ -574,6 +574,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { @@ -584,6 +585,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -600,6 +602,7 @@ case class TransformWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 50ab2a41612b4..d62381f2b823e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -656,12 +656,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the * operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. * */ private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode = OutputMode.Append()): Dataset[U] = { Dataset[U]( sparkSession, @@ -669,6 +671,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f77d0fef4eb95..5bedcaf3e6e01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -751,7 +751,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputAttr, child) => val execPlan = TransformWithStateExec( keyDeserializer, @@ -759,6 +759,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -917,10 +918,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, child) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, planLater(child)) :: Nil case _: FlatMapGroupsInPandasWithState => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 662bef5716ea2..14d9a877ea054 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} -import org.apache.spark.sql.streaming.ListState +import org.apache.spark.sql.streaming.{ListState, TTLMode} /** * Provides concrete implementation for list of values associated with a state variable @@ -37,12 +40,18 @@ class ListStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) - extends ListState[S] with Logging { + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ListState[S] + with Logging + with StateVariableWithTTLSupport { private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLState] = None store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) @@ -59,29 +68,65 @@ class ListStateImpl[S]( * empty iterator is returned. */ override def get(): Iterator[S] = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + var currentRow: UnsafeRow = null + var shouldGetNextValidRow = true + var isFirst = true + new Iterator[S] { override def hasNext: Boolean = { - unsafeRowValuesIterator.hasNext + if (shouldGetNextValidRow) { + getNextValidRow() + } + currentRow != null } override def next(): S = { - val valueUnsafeRow = unsafeRowValuesIterator.next() - stateTypesEncoder.decodeValue(valueUnsafeRow) + if (shouldGetNextValidRow) { + getNextValidRow() + } + if (currentRow == null) { + throw new NoSuchElementException("Iterator is at the end") + } + shouldGetNextValidRow = true + stateTypesEncoder.decodeValue(currentRow) + } + + // sets currentRow to a valid state, where we are + // pointing to a non-expired row + private def getNextValidRow(): Unit = { + assert(shouldGetNextValidRow) + while (unsafeRowValuesIterator.hasNext && (isFirst || isExpired(currentRow))) { + isFirst = false + currentRow = unsafeRowValuesIterator.next() + } + // in this case, we have iterated to the end, and there are no + // non-expired values + if (isExpired(currentRow)) { + currentRow = null + } + shouldGetNextValidRow = false } } } /** Update the value of the list. 
*/ - override def put(newState: Array[S]): Unit = { + override def put(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v) + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) if (isFirst) { store.put(encodedKey, encodedValue, stateName) isFirst = false @@ -89,24 +134,42 @@ class ListStateImpl[S]( store.merge(encodedKey, encodedValue, stateName) } } + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Append an entry to the list. */ - override def appendValue(newState: S): Unit = { + override def appendValue(newState: S, ttlDuration: Duration = Duration.ZERO): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + val encodedValue = stateTypesEncoder.encodeValue(newState, expirationMs) store.merge(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + encodedValue, stateName) + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Append an entire list to the existing value. */ - override def appendList(newState: Array[S]): Unit = { + override def appendList(newState: Array[S], ttlDuration: Duration = Duration.ZERO): Unit = { validateNewState(newState) + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + val encodedKey = stateTypesEncoder.encodeGroupingKey() newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v) + val encodedValue = stateTypesEncoder.encodeValue(v, expirationMs) store.merge(encodedKey, encodedValue, stateName) } + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, + stateTypesEncoder.serializeGroupingKey())) } /** Remove this state. */ @@ -122,4 +185,48 @@ class ListStateImpl[S]( StateStoreErrors.requireNonNullStateValue(v, stateName) } } - } + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. 
+ */ + override def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) + clear() + var isFirst = true + + unsafeRowValuesIterator.foreach { encodedValue => + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + if (encodedValue != null) { + if (isExpired(encodedValue)) { + store.remove(encodedGroupingKey, stateName) + } else { + if (isFirst) { + store.put(encodedKey, encodedValue, stateName) + isFirst = false + } else { + store.merge(encodedKey, encodedValue, stateName) + } + } + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index d2ccd0a778074..c58f32ed756db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -45,7 +45,7 @@ class MapStateImpl[K, V]( /** Whether state exists or not. */ override def exists(): Boolean = { - !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty + store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1d41db896cdf2..f97b91ccc9606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -23,11 +23,19 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructType} object StateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) - val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType) + val VALUE_ROW_SCHEMA: StructType = new StructType() + .add("value", BinaryType) + .add("ttlExpirationMs", LongType) + + def encodeGroupingKeyBytes(keyBytes: Array[Byte]): UnsafeRow = { + val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) + val keyRow = keyProjection(InternalRow(keyBytes)) + keyRow + } } /** @@ -62,33 +70,79 @@ class StateTypesEncoder[GK, V]( private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer() private val reusedValRow = new UnsafeRow(valEncoder.schema.fields.length) + private val NO_TTL_ENCODED_VALUE: Long = -1L + // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. 
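To make the widened value-row layout concrete, here is a small, self-contained sketch of how a serialized state object and its `ttlExpirationMs` are packed into a single `UnsafeRow` and read back, using the same catalyst helpers the encoder uses. The byte payload, the timestamp and the `TtlValueRowSketch` name are invented for illustration; `-1L` plays the same "no TTL" role as `NO_TTL_ENCODED_VALUE`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{BinaryType, LongType, StructType}

object TtlValueRowSketch extends App {
  // Mirrors the widened VALUE_ROW_SCHEMA: serialized state bytes plus a
  // ttlExpirationMs column, where -1L means "no TTL was set".
  val valueSchema = new StructType()
    .add("value", BinaryType)
    .add("ttlExpirationMs", LongType)
  val valueProjection = UnsafeProjection.create(valueSchema)

  // Pack some (made-up) serialized state bytes with an expiration timestamp.
  val packed: UnsafeRow = valueProjection(InternalRow(Array[Byte](1, 2, 3), 1700000000000L))

  // Read both columns back, treating -1L as "no TTL", as decodeTtlExpirationMs does.
  val stateBytes: Array[Byte] = packed.getBinary(0)
  val expirationMs: Option[Long] = packed.getLong(1) match {
    case -1L => None
    case ms => Some(ms)
  }
  println(s"${stateBytes.length} value bytes, expires at $expirationMs")
}
```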
def encodeGroupingKey(): UnsafeRow = { + val keyRow = keyProjection(InternalRow(serializeGroupingKey())) + keyRow + } + + /** + * Encodes the provided grouping key into Spark UnsafeRow. + * + * @param groupingKeyBytes serialized grouping key byte array + * @return encoded UnsafeRow + */ + def encodeSerializedGroupingKey( + groupingKeyBytes: Array[Byte]): UnsafeRow = { + val keyRow = keyProjection(InternalRow(groupingKeyBytes)) + keyRow + } + + def serializeGroupingKey(): Array[Byte] = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(stateName) } - val groupingKey = keyOption.get.asInstanceOf[GK] - val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() - val keyRow = keyProjection(InternalRow(keyByteArr)) - keyRow + keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } + /** + * Encode the specified value in Spark UnsafeRow with no ttl. + * The ttl expiration will be set to -1, specifying no TTL. + */ def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes)) + val valRow = valueProjection(InternalRow(bytes, NO_TTL_ENCODED_VALUE)) valRow } + /** + * Encode the specified value in Spark UnsafeRow + * with provided ttl expiration. + */ + def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + val valueRow = valueProjection(InternalRow(bytes, expirationMs)) + valueRow + } + def decodeValue(row: UnsafeRow): V = { val bytes = row.getBinary(0) reusedValRow.pointTo(bytes, bytes.length) val value = rowToObjDeserializer.apply(reusedValRow) value } + + /** + * Decode the ttl information out of Value row. If the ttl has + * not been set (-1L specifies no user defined value), the API will + * return None. + */ + def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { + val expirationMs = row.getLong(1) + + if (expirationMs == NO_TTL_ENCODED_VALUE) { + None + } else { + Some(expirationMs) + } + } } object StateTypesEncoder { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala new file mode 100644 index 0000000000000..286b926577fbd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateVariableWithTTLSupport.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +/** + * Encapsulates the ttl row information stored in [[SingleKeyTTLState]]. + * @param groupingKey grouping key for which ttl is set + * @param expirationMs expiration time for the grouping key + */ +case class SingleKeyTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + +/** + * Represents a State variable which supports TTL. + */ +trait StateVariableWithTTLSupport { + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, it is possible that the user + * has updated the TTL and the secondary index is out of date. Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +/** + * Represents the underlying state for the secondary TTL index for a user defined + * state variable. + * + * This state allows Spark to query ttl values based on expiration time, + * allowing efficient ttl cleanup. + */ +trait TTLState { + + /** + * Performs the user state cleanup based on ttl values stored in + * this state. NOTE that it is not safe to call this operation concurrently + * when the user can also modify the underlying State. Cleanup should be initiated + * after arbitrary state operations are completed by the user. + */ + def clearExpiredState(): Unit +} + +/** + * Manages the ttl information for user state keyed with a single key (grouping key).
+ */ +class SingleKeyTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends TTLState + with Logging { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA) + private var state: StateVariableWithTTLSupport = _ + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + // TODO: use range scan once Range Scan PR is merged for StateStore + store.createColFamilyIfAbsent(ttlColumnFamilyName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + override def clearExpiredState(): Unit = { + store.iterator(ttlColumnFamilyName).foreach { kv => + val expirationMs = kv.key.getLong(0) + val isExpired = StateTTL.isExpired(ttlMode, expirationMs, + batchTimestampMs, eventTimeWatermarkMs) + + if (isExpired) { + val groupingKey = kv.key.getBinary(1) + state.clearIfExpired(groupingKey) + + store.remove(kv.key, ttlColumnFamilyName) + } + } + } + + private[sql] def setStateVariable( + state: StateVariableWithTTLSupport): Unit = { + this.state = state + } + + private[sql] def iterator(): Iterator[SingleKeyTTLRow] = { + val ttlIterator = store.iterator(ttlColumnFamilyName) + + new Iterator[SingleKeyTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyTTLRow = { + val kv = ttlIterator.next() + SingleKeyTTLRow( + expirationMs = kv.key.getLong(0), + groupingKey = kv.key.getBinary(1) + ) + } + } + } +} + +/** + * Helper methods for user State TTL. 
+ */ +object StateTTL { + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + throw new IllegalStateException(s"cannot calculate expiration time for" + + s" unknown ttl Mode $ttlMode") + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get >= expirationMs + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get >= expirationMs + } else { + throw new IllegalStateException(s"cannot evaluate expiry condition for" + + s" unknown ttl Mode $ttlMode") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 9b905ad5235db..b50e5507a4d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.util import java.util.UUID import org.apache.spark.TaskContext @@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, TTLMode, ValueState} import org.apache.spark.util.Utils /** @@ -77,14 +78,20 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, - isStreaming: Boolean = true) + isStreaming: Boolean = true, + batchTimestampMs: Option[Long] = None, + eventTimeWatermarkMs: Option[Long] = None) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + private val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - private def buildQueryInfo(): QueryInfo = { + logInfo(s"Created StatefulProcessorHandle") + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { (BATCH_QUERY_ID, 0L) @@ -103,21 +110,26 @@ class StatefulProcessorHandleImpl( private var currState: StatefulProcessorHandleState = CREATED - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } - def setHandleState(newState: StatefulProcessorHandleState): Unit = { currState = newState } def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state") - val 
resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, + ttlMode, batchTimestampMs, eventTimeWatermarkMs) + val ttlState = resultState.ttlState + + ttlState.foreach { s => + ttlStates.add(s) + s.setStateVariable(resultState) + } + resultState } @@ -183,6 +195,16 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + /** + * Performs the user state cleanup based on assigned TTl values. Any state + * which is expired will be cleaned up from StateStore. + */ + def doTtlCleanup(): Unit = { + ttlStates.forEach { s => + s.clearExpiredState() + } + } + /** * Function to delete and purge state variable if defined previously * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala new file mode 100644 index 0000000000000..4cdfe9e2f9ecc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorTTLState.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTTL, NoTTL, ProcessingTimeTTL} +import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.TTLMode +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +class StatefulProcessorTTLState( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long], + createColumnFamily: Boolean = true) { + + private val ttlColumnFamilyName = s"_ttl_$stateName" + + private val schemaForKeyRow = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + .add("userKey", BinaryType) + private val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + + private val ttlKeyEncoder = UnsafeProjection.create(schemaForKeyRow) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + validate() + if (createColumnFamily) { + store.createColFamilyIfAbsent(ttlColumnFamilyName, schemaForKeyRow, + schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), isInternal = true) + } + + private def validate(): Unit = { + ttlMode match { + case NoTTL => + throw new RuntimeException() + case ProcessingTimeTTL if batchTimestampMs.isEmpty => + throw new IllegalStateException() + case EventTimeTTL if eventTimeWatermarkMs.isEmpty => + throw new IllegalStateException() + case _ => + } + } + + private def expiredKeysIterator(): Iterator[InternalRow] = { + // TODO(sahnib): Need to merge Anish's changes + store.iterator(ttlColumnFamilyName).foreach { kv => + + } + + Iterator.empty + } + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte], + userKey: Option[Array[Byte]]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, + groupingKey, userKey.orNull)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } +} + +object StatefulProcessorTTLState { + def apply( + ttlMode: TTLMode, + stateName: String, + store: StateStore, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): StatefulProcessorTTLState = { + new StatefulProcessorTTLState(ttlMode, stateName, store, batchTimestampMs, eventTimeWatermarkMs) + } + + def calculateExpirationTimeForDuration( + ttlMode: TTLMode, + ttlDuration: Duration, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + ttlDuration.toMillis + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + ttlDuration.toMillis + } else { + -1L + } + } + + def getCurrentExpirationTime( + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Long = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get + } else if (ttlMode == TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get + } else { + -1L + } + } + + def isExpired( + ttlMode: TTLMode, + expirationMs: Long, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]): Boolean = { + if (ttlMode == TTLMode.ProcessingTimeTTL()) { + batchTimestampMs.get > expirationMs + } else if (ttlMode == 
TTLMode.EventTimeTTL()) { + eventTimeWatermarkMs.get > expirationMs + } else { + false + } + } + + def encodedTtlModeValue(ttLMode: TTLMode): Short = { + ttLMode match { + case NoTTL => + 0 + case ProcessingTimeTTL => + 1 + case EventTimeTTL => + 2 + } + } + + def decodedTtlMode(encodedVal: Short): TTLMode = { + encodedVal match { + case 0 => + TTLMode.NoTTL() + case 1 => + TTLMode.ProcessingTimeTTL() + case 2 => + TTLMode.EventTimeTTL() + case _ => + throw new IllegalStateException("encodedTtlValue should be <= 2") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index 6166374d25e94..170c7c205a8ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -78,25 +78,25 @@ class TimerStateImpl( private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) - val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { TimerStateUtils.PROC_TIMERS_STATE_NAME } else { TimerStateUtils.EVENT_TIMERS_STATE_NAME } - val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), useMultipleValuesPerKey = false, isInternal = true) - val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex), useMultipleValuesPerKey = false, isInternal = true) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption - if (!keyOption.isDefined) { + if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(cfName) } keyOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 39365e92185ad..3a1b31422dbc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -40,6 +40,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param 
groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data + * @param ttlMode defines the ttl Mode for user state * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -56,6 +57,7 @@ case class TransformWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -90,16 +92,13 @@ case class TransformWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes - protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - override def requiredChildDistribution: Seq[Distribution] = { StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, getStateInfo, conf) :: Nil } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( groupingAttributes.map(SortOrder(_, Ascending))) @@ -241,6 +240,8 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { + // clean up any expired user state + processorHandle.doTtlCleanup() store.commit() } else { store.abort() @@ -257,26 +258,15 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - timeoutMode match { - case ProcessingTime => - if (batchTimestampMs.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case EventTime => - if (eventTimeWatermarkForEviction.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case _ => - } + validateTTLMode() + validateTimeoutMode() if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator), useColumnFamilies = true, @@ -306,9 +296,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider = StateStoreProvider.createAndInit( providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useColumnFamilies = true, storeConf = storeConf, hadoopConf = broadcastedHadoopConf.value, @@ -334,15 +324,49 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode, + isStreaming, batchTimestampMs, eventTimeWatermarkForEviction) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeoutMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) 
processDataWithPartition(singleIterator, store, processorHandle) } + + private def validateTimeoutMode(): Unit = { + timeoutMode match { + case ProcessingTime => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + } + + private def validateTTLMode(): Unit = { + ttlMode match { + case ProcessingTimeTTL => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case EventTimeTTL => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case _ => + } + } } +// scalastyle:off argcount object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -352,6 +376,7 @@ object TransformWithStateExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -372,6 +397,7 @@ object TransformWithStateExec { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -384,3 +410,6 @@ object TransformWithStateExec { isStreaming = false) } } + +// scalastyle:on argcount + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 08876ca3032ee..8d9eaeb39f919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -16,39 +16,61 @@ */ package org.apache.spark.sql.execution.streaming +import java.time.Duration + import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} -import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.streaming.{TTLMode, ValueState} /** * Class that provides a concrete implementation for a single value state associated with state * variables used in the streaming transformWithState operator. * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition - * @param keyEnc - Spark SQL encoder for key + * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param ttlMode - TTL Mode for values stored in this state + * @param batchTimestampMs processing timestamp of the current batch. 
+ * @param eventTimeWatermarkMs event time watermark for streaming query + * (same as watermark for state eviction) * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) extends ValueState[S] with Logging { + valEncoder: Encoder[S], + ttlMode: TTLMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkMs: Option[Long]) + extends ValueState[S] + with Logging + with StateVariableWithTTLSupport { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + private[sql] var ttlState: Option[SingleKeyTTLState] = None + + initialize() - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + + if (ttlMode != TTLMode.NoTTL()) { + val _ttlState = new SingleKeyTTLState(ttlMode, stateName, store, + batchTimestampMs, eventTimeWatermarkMs) + ttlState = Some(_ttlState) + } + } /** Function to check if state exists. Returns true if present and false otherwise */ override def exists(): Boolean = { - getImpl() != null + get() != null } /** Function to return Option of value if exists and None otherwise */ @@ -58,26 +80,138 @@ class ValueStateImpl[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val retRow = getImpl() + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + if (retRow != null) { - stateTypesEncoder.decodeValue(retRow) + val resState = stateTypesEncoder.decodeValue(retRow) + + if (!isExpired(retRow)) { + resState + } else { + null.asInstanceOf[S] + } } else { null.asInstanceOf[S] } } - private def getImpl(): UnsafeRow = { - store.get(stateTypesEncoder.encodeGroupingKey(), stateName) - } - /** Function to update and overwrite state associated with given key */ - override def update(newState: S): Unit = { - store.put(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + override def update( + newState: S, + ttlDuration: Duration = Duration.ZERO): Unit = { + + if (ttlDuration != Duration.ZERO && ttlState.isEmpty) { + throw StateStoreErrors.cannotProvideTTLDurationForNoTTLMode("update", stateName) + } + + var expirationMs: Long = -1 + if (ttlDuration != null && ttlDuration != Duration.ZERO) { + expirationMs = StateTTL.calculateExpirationTimeForDuration( + ttlMode, ttlDuration, batchTimestampMs, eventTimeWatermarkMs) + } + + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + stateTypesEncoder.encodeValue(newState, expirationMs), stateName) + + ttlState.foreach(_.upsertTTLForStateKey(expirationMs, serializedGroupingKey)) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + if (isExpired(retRow)) { + store.remove(encodedGroupingKey, stateName) + } + } + } + + private def 
isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + val isExpired = expirationMs.map( + StateTTL.isExpired(ttlMode, _, batchTimestampMs, eventTimeWatermarkMs)) + + isExpired.isDefined && isExpired.get + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. + */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + if (ttlState.isEmpty) { + Iterator.empty + } + + val ttlIterator = ttlState.get.iterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + + result + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1817104a5c223..1c1159f654546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,10 @@ class RocksDB( } } + def listColumnFamilies(): Seq[String] = { + Seq() + } + /** * Remove RocksDB column family, if exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c8537f2a6a5b1..c6db9b5d000c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -257,6 +257,13 @@ private[sql] class RocksDBStateStoreProvider rocksDB.removeColFamilyIfExists(colFamilyName) keyValueEncoderMap.remove(colFamilyName) } + + /** Return a list of column family names */ + override def listColumnFamilies(): Seq[String] = { + verify(useColumnFamilies, "Column families are not supported in this store") + // turn keys of keyValueEncoderMap to Seq + 
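
The get() path above enforces TTL lazily: an expired row is reported as missing even before the TTL index (SingleKeyTTLState) deletes it through clearIfExpired, and getWithoutEnforcingTTL() exists only so tests can observe when the row is physically gone. A small self-contained sketch of that read-side check, using hypothetical names:

    // Hypothetical model: a stored value carries an optional absolute expiration and
    // reads filter on it; physical deletion is left to the separate TTL cleanup path.
    object ExpiryOnReadSketch {
      final case class StoredValue[S](value: S, expirationMs: Option[Long])

      def readIfAlive[S](row: Option[StoredValue[S]], nowMs: Long): Option[S] =
        row.collect {
          // keep the value only if it has no expiration or has not yet expired
          case StoredValue(v, exp) if exp.forall(nowMs < _) => v
        }
    }
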
rocksDB.listColumnFamilies() + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dd97aa5b9afca..c79c244dcf443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -122,6 +122,11 @@ trait StateStore extends ReadStateStore { */ def removeColFamilyIfExists(colFamilyName: String): Unit + /** + * Return list of column family names. + */ + def listColumnFamilies(): Seq[String] = Seq(StateStore.DEFAULT_COL_FAMILY_NAME) + /** * Create column family with given name, if absent. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index a8d4c06bc83c4..cad31b088e96c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -39,6 +39,13 @@ object StateStoreErrors { ) } + def missingTTLValues(ttlMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for ttlMode=$ttlMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -49,6 +56,11 @@ object StateStoreErrors { new StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider) } + def ttlNotSupportedWithProvider(stateStoreProvider: String): + StateStoreTTLNotSupportedException = { + new StateStoreTTLNotSupportedException(stateStoreProvider) + } + def removingColumnFamiliesNotSupported(stateStoreProvider: String): StateStoreRemovingColumnFamiliesNotSupportedException = { new StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider) @@ -112,12 +124,24 @@ object StateStoreErrors { handleState: String): StatefulProcessorCannotPerformOperationWithInvalidHandleState = { new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState) } + + def cannotProvideTTLDurationForNoTTLMode(operationType: String, + stateName: String): StatefulProcessorCannotAssignTTLInNoTTLMode = { + new StatefulProcessorCannotAssignTTLInNoTTLMode(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", - messageParameters = Map("stateStoreProvider" -> stateStoreProvider)) + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) + +class StateStoreTTLNotSupportedException(stateStoreProvider: String) + extends SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_TTL", + messageParameters = Map("stateStoreProvider" -> stateStoreProvider) + ) class StateStoreRemovingColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( @@ -182,3 +206,10 @@ class StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin extends SparkUnsupportedOperationException( errorClass = 
"STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", messageParameters = Map("fieldName" -> fieldName, "index" -> index)) + +class StatefulProcessorCannotAssignTTLInNoTTLMode( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index e895e475b74d9..51cfc1548b398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, ValueState} /** * Class that adds unit tests for ListState types used in arbitrary stateful @@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase { } checkError( - exception = e.asInstanceOf[SparkIllegalArgumentException], + exception = e, errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") @@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index ce72061d39ea2..7fa41b12795eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 662a5dbfaac4f..aec828459fce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode} + /** * Class that adds tests to verify operations based on stateful processor handle @@ -48,7 +49,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -89,7 +90,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -107,7 +108,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } @@ -143,7 +144,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -164,7 +165,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -204,7 +205,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), 
keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8668b58672c7e..8063c2cdb155e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{TimeoutMode, TTLMode, ValueState} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -48,7 +48,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -78,7 +79,7 @@ class ValueStateSuite extends StateVariableSuiteBase { testState.update(123) } checkError( - ex.asInstanceOf[SparkException], + ex1.asInstanceOf[SparkException], errorClass = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" @@ -92,7 +93,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -118,7 +120,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,7 +167,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], - TimeoutMode.NoTimeouts()) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val cfName = "_testState" val ex = intercept[SparkUnsupportedOperationException] { @@ -204,7 +207,8 @@ class ValueStateSuite extends StateVariableSuiteBase { 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +234,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 +261,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +288,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 95ab34d401311..51cc9ff87890c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -140,6 +140,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -160,6 +161,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -180,6 +182,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -200,6 +203,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, 
OutputMode.Update())( @@ -220,6 +224,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -240,6 +245,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -260,6 +266,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -312,6 +319,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x) .transformWithState(new ToggleSaveAndEmitProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index d7c5ce3815b04..46ed7e3fb3f5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -94,6 +94,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) @@ -120,6 +121,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -144,6 +146,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -167,6 +170,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) testStream(result, OutputMode.Append())( // Test exists() @@ -221,6 +225,7 @@ class TransformWithMapStateSuite extends StreamTest { .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24b0d59c45c56..79422d6fa83ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -231,6 +231,7 @@ class RunningCountMostRecentStatefulProcessor _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } + override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -310,6 +311,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new 
RunningCountStatefulProcessorWithError(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -331,6 +333,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -361,6 +364,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -404,6 +408,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -440,6 +445,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithMultipleTimers(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -475,6 +481,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new MaxEventTimeStatefulProcessor(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -516,6 +523,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() @@ -534,12 +542,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x._1) .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) val stream2 = inputData.toDS() .groupByKey(x => x._1) .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(stream1, OutputMode.Update())( @@ -572,6 +582,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -605,6 +616,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -638,6 +650,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -668,6 +681,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) } @@ -760,6 +774,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala new file mode 100644 index 0000000000000..a1acf3689d0ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLSuite.scala @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.sql.Timestamp +import java.time.{Duration, Instant} +import java.time.temporal.ChronoUnit + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock + +case class InputEvent( + key: String, + action: String, + value: Int, + ttl: Duration, + eventTime: Timestamp = null) + +case class OutputEvent( + key: String, + value: Int, + isTTLValue: Boolean, + ttlValue: Long) + +object TTLInputProcessFunction { + def processRow( + row: InputEvent, + valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = valueState.getWithoutEnforcingTTL() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpiration = valueState.getTTLValue() + if (ttlExpiration.isDefined) { + results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results + } + } else if (row.action == "put") { + valueState.update(row.value, row.ttl) + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = valueState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } +} + +class ValueStateTTLProcessor + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueState: ValueStateImpl[Int] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _valueState = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: 
ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + for (row <- inputRows) { + val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + resultIter.foreach { r => + results = r :: results + } + } + + results.iterator + } +} + +case class MultipleValueStatesTTLProcessor( + ttlKey: String, + noTtlKey: String) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueStateWithTTL: ValueStateImpl[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _valueStateWithTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + _valueStateWithoutTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val state = if (key == ttlKey) { + _valueStateWithTTL + } else { + _valueStateWithoutTTL + } + + for (row <- inputRows) { + val resultIterator = TTLInputProcessFunction.processRow(row, state) + resultIterator.foreach { r => + results = r :: results + } + } + results.iterator + } +} + +class TransformWithStateTTLSuite + extends StreamTest { + import testImplicits._ + + test("validate state is evicted at ttl expiry - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + 
classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update expiration time + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is updated in the state + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)), + // validate ttl state has both ttl values present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000), + OutputEvent("k1", -1, isTTLValue = true, 95000) + ), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)) + ) + } + } + + test("validate ttl removal keeps value in state - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1, Duration.ofMinutes(1))), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 
* 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update state to remove ttl + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1, null)), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate multiple value states - with and without ttl - processing time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val ttlKey = "k1" + val noTtlKey = "k2" + + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1, Duration.ofMinutes(1))), + AddData(inputStream, InputEvent(noTtlKey, "put", 2, Duration.ZERO)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get both state values, and make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1, null)), + AdvanceManualClock(1 * 1000), + 
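
The two processing-time tests above pin down a subtle part of the design: re-writing a key leaves the old entry behind in the TTL index, and the cleanup pass only deletes the primary value after re-checking its current expiration, which is what clearIfExpired does. A standalone sketch of that behaviour, with made-up container names:

    import scala.collection.mutable

    // Made-up sketch: stale TTL-index entries are purged once the clock passes them,
    // but the primary value survives unless it is itself still expired at that time.
    object TtlIndexCleanupSketch {
      private val primary = mutable.Map.empty[String, (Int, Option[Long])]
      private val ttlIndex = mutable.Set.empty[(Long, String)] // (expirationMs, key)

      def put(key: String, value: Int, expirationMs: Option[Long]): Unit = {
        primary(key) = (value, expirationMs)               // overwrites the old value
        expirationMs.foreach(e => ttlIndex += ((e, key)))  // old index entries remain
      }

      def cleanupUpTo(nowMs: Long): Unit = {
        val due = ttlIndex.filter { case (exp, _) => exp <= nowMs }
        ttlIndex --= due
        due.foreach { case (_, key) =>
          primary.get(key).foreach { case (_, currentExp) =>
            // delete only if the value is still expired; a refreshed key is kept
            if (currentExp.exists(_ <= nowMs)) primary -= key
          }
        }
      }
    }

This matches the expectations above: after the key is refreshed with a 95000 ms expiration, the stale 61000 ms index entry disappears once the clock passes it, while the value itself stays readable.
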
CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1, null)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1, null)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, bot noTtlKey is still present + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1, null)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate state is evicted at ttl expiry - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // increment event time so that key k1 expires + AddData(inputStream, InputEvent("k2", "put", 1, ttlDuration, eventTime3)), + CheckNewAnswer(), + // validate that k1 has expired + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null, eventTime3)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + CheckNewAnswer() + ) + } + } + + test("validate ttl update updates the expiration timestamp - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = 
Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, + InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable(), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // remove tll expiration for key k1, and move watermark past previous ttl value + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime3)), + CheckNewAnswer(), + // validate that the key still exists + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // ensure this ttl expiration time has been removed from state + AddData(inputStream, + InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + ProcessAllAvailable() + ) + } + } + + test("validate ttl removal keeps value in state - event time ttl") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .withWatermark("eventTime", "1 second") + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(), + TimeoutMode.NoTimeouts(), + TTLMode.EventTimeTTL()) + + val eventTime1 = Timestamp.from(Instant.EPOCH.plus(30, ChronoUnit.SECONDS)) + val eventTime2 = Timestamp.from(Instant.EPOCH.plus(60, ChronoUnit.SECONDS)) + val ttlDuration = Duration.ofMinutes(1) + val ttlExpirationMs = Instant.EPOCH.plus(ttlDuration).toEpochMilli + val eventTime3 = Timestamp.from(Instant.ofEpochMilli(ttlExpirationMs + 1000)) + + testStream(result)( + AddData(inputStream, InputEvent("k1", "put", 1, ttlDuration, eventTime1)), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", 1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime2)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // update state and remove ttl + AddData(inputStream, InputEvent("k1", "put", 2, null, eventTime2)), + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime2)), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null, eventTime2)), + CheckNewAnswer(), + // validate ttl state still has old ttl value present + 
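
The event-time expectations in these tests follow from anchoring expiration to the eviction watermark: assuming the watermark is still at the epoch (0 ms) when the first put is processed, a one-minute TTL expires at 60000 ms, which is why ttlExpirationMs is derived as Instant.EPOCH.plus(ttlDuration).toEpochMilli, and eventTime3 = ttlExpirationMs + 1000 is chosen so that, after the 1 second watermark delay, the watermark reaches that expiration and the key becomes eligible for eviction.
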
AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, ttlExpirationMs)), + // eventTime has been advanced to eventTim3 which is after older expiration value + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1, null, eventTime3)), + CheckNewAnswer(OutputEvent("k1", 2, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)), + CheckNewAnswer() + ) + } + } +}