@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{ListState, TTLConfig}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
@@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
  * @param ttlConfig - TTL configuration for values stored in this state
  * @param batchTimestampMs - current batch processing timestamp.
  * @param metrics - metrics to be updated as part of stateful processing
+ * @param avroEnc - optional Avro serializer and deserializer for this state variable that
+ *                  is used by the StateStore to encode state in Avro format
+ * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ *                     is used by the StateStore to encode state in Avro format
  * @tparam S - data type of object that will be stored
  */
 class ListStateImplWithTTL[S](
@@ -45,8 +49,10 @@ class ListStateImplWithTTL[S](
     valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
-    metrics: Map[String, SQLMetric] = Map.empty)
-  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoderSpec] = None,
+    ttlAvroEnc: Option[AvroEncoderSpec] = None)
+  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc)
   with ListStateMetricsImpl
   with ListState[S] {
@@ -65,7 +71,8 @@ class ListStateImplWithTTL[S](
   private def initialize(): Unit = {
     store.createColFamilyIfAbsent(stateName, keyExprEnc.schema,
       getValueSchemaWithTTL(valEncoder.schema, true),
-      NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true)
+      NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true,
+      avroEncoderSpec = avroEnc)
   }

   /** Whether state exists or not. */
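Both new constructor parameters default to None, so Avro encoding stays strictly opt-in and existing call sites compile unchanged. For orientation, the spec they carry looks roughly like the sketch below; this shape is inferred from the `AvroEncoderSpec(...)` call inside `getAvroSerde` later in this diff, and the field names are assumptions rather than the committed definition.

```scala
import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer}

// Inferred shape of the spec threaded through the state variables (sketch only).
case class AvroEncoderSpec(
    keySerializer: AvroSerializer,
    keyDeserializer: AvroDeserializer,
    valueSerializer: AvroSerializer,
    valueDeserializer: AvroDeserializer,
    suffixKeySerializer: Option[AvroSerializer] = None,
    suffixKeyDeserializer: Option[AvroDeserializer] = None)
```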
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{MapState, TTLConfig}
 import org.apache.spark.util.NextIterator
@@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
  * @param ttlConfig - the ttl configuration (time to live duration etc.)
  * @param batchTimestampMs - current batch processing timestamp.
  * @param metrics - metrics to be updated as part of stateful processing
+ * @param avroEnc - optional Avro serializer and deserializer for this state variable that
+ *                  is used by the StateStore to encode state in Avro format
+ * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ *                     is used by the StateStore to encode state in Avro format
  * @tparam K - type of key for map state variable
  * @tparam V - type of value for map state variable
  * @return - instance of MapState of type [K,V] that can be used to store state persistently
@@ -48,9 +52,11 @@ class MapStateImplWithTTL[K, V](
     valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
-    metrics: Map[String, SQLMetric] = Map.empty)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoderSpec] = None,
+    ttlAvroEnc: Option[AvroEncoderSpec] = None)
   extends CompositeKeyTTLStateImpl[K](stateName, store,
-    keyExprEnc, userKeyEnc, batchTimestampMs)
+    keyExprEnc, userKeyEnc, batchTimestampMs, ttlAvroEnc)
   with MapState[K, V] with Logging {

   private val stateTypesEncoder = new CompositeKeyStateEncoder(
@@ -66,7 +72,8 @@ class MapStateImplWithTTL[K, V](
       getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema)
     store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow,
       getValueSchemaWithTTL(valEncoder.schema, true),
-      PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
+      PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1),
+      avroEncoderSpec = avroEnc)
   }

   /** Whether state exists or not. */
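A note on the encoder spec above: the second argument to `PrefixKeyScanStateEncoderSpec` is the number of leading key columns that form the scan prefix. A minimal sketch of the composite key layout this implies for map state, with field names assumed for illustration:

```scala
import org.apache.spark.sql.types._

// Composite key row for map state: the grouping key at ordinal 0 is the scan
// prefix (numColsPrefixKey = 1); the user-supplied map key follows at ordinal 1.
val compositeKey = StructType(Seq(
  StructField("groupingKey", BinaryType, nullable = false),
  StructField("userKey", StringType)))
```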
@@ -20,14 +20,45 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
-import org.apache.spark.sql.execution.streaming.state.AvroEncoderSpec
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StructField, StructType}

 object StateStoreColumnFamilySchemaUtils {

   def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils =
     new StateStoreColumnFamilySchemaUtils(initializeAvroSerde)

+  /**
+   * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans
+   * we want to use big-endian encoding, so we need to convert the source schema to replace these
+   * types with BinaryType.
+   *
+   * @param schema The schema to convert
+   * @param ordinals If non-empty, only convert fields at these ordinals.
+   *                 If empty, convert all fields.
+   */
+  def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = {
+    val ordinalSet = ordinals.toSet
+    StructType(schema.fields.zipWithIndex.map { case (field, idx) =>
+      if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) {
+        // Convert numeric types to BinaryType while preserving nullability
+        field.copy(dataType = BinaryType)
+      } else {
+        field
+      }
+    })
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
+         _: FloatType | _: DoubleType => true
+    case _ => false
+  }
+
+  def getTtlColFamilyName(stateName: String): String = {
+    "$ttl_" + stateName
+  }
 }
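A quick illustration of both the ordering problem and the resulting schema rewrite. The schema below is a standalone sketch with assumed field names, and the trailing comments show what each expression evaluates to:

```scala
import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
import org.apache.spark.sql.types._

// Avro's zig-zag mapping: numeric order -1 < 0 < 1 becomes byte order 0, -1, 1,
// so zig-zag-encoded longs cannot back a lexicographic range scan.
def zigZag(n: Long): Long = (n << 1) ^ (n >> 63)
Seq(-1L, 0L, 1L).map(zigZag) // List(1, 0, 2)

// Rewrite only ordinal 0 (the expiration timestamp) to BinaryType; the
// remaining key field is left alone and handled by the Avro suffix serde.
val ttlKey = StructType(Seq(
  StructField("expirationMs", LongType, nullable = false),
  StructField("groupingKey", BinaryType)))
StateStoreColumnFamilySchemaUtils.convertForRangeScan(ttlKey, Seq(0))
// StructType(StructField(expirationMs, BinaryType, false), StructField(groupingKey, BinaryType, true))
```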

@@ -45,7 +76,7 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
   private def getAvroSerde(
       keySchema: StructType,
       valSchema: StructType,
-      userKeySchema: Option[StructType] = None
+      suffixKeySchema: Option[StructType] = None
   ): Option[AvroEncoderSpec] = {
     if (initializeAvroSerde) {
       val avroType = SchemaConverters.toAvroType(valSchema)
@@ -59,18 +90,18 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       val valueDeserializer = new AvroDeserializer(avroType, valSchema,
         avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
         avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-      val (userKeySerializer, userKeyDeserializer) = if (userKeySchema.isDefined) {
-        val userKeyAvroType = SchemaConverters.toAvroType(userKeySchema.get)
-        val ukSer = new AvroSerializer(userKeySchema.get, userKeyAvroType, nullable = false)
-        val ukDe = new AvroDeserializer(userKeyAvroType, userKeySchema.get,
+      val (suffixKeySer, suffixKeyDe) = if (suffixKeySchema.isDefined) {
+        val userKeyAvroType = SchemaConverters.toAvroType(suffixKeySchema.get)
+        val skSer = new AvroSerializer(suffixKeySchema.get, userKeyAvroType, nullable = false)
+        val skDe = new AvroDeserializer(userKeyAvroType, suffixKeySchema.get,
           avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
           avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-        (Some(ukSer), Some(ukDe))
+        (Some(skSer), Some(skDe))
       } else {
         (None, None)
       }
       Some(AvroEncoderSpec(
-        keySer, keyDe, valueSerializer, valueDeserializer, userKeySerializer, userKeyDeserializer))
+        keySer, keyDe, valueSerializer, valueDeserializer, suffixKeySer, suffixKeyDe))
     } else {
       None
     }
@@ -126,6 +157,47 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
     )
   }

+  def getTtlStateSchema(
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
+    val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+      getSingleKeyTTLRowSchema(keyEncoder.schema), Seq(0))
+    val ttlValSchema = StructType(
+      Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(ttlKeySchema.take(1)),
+        ttlValSchema,
+        Some(StructType(ttlKeySchema.drop(1)))
+      )
+    )
+  }
+
+  def getTtlStateSchema(
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any],
+      userKeySchema: StructType): StateStoreColFamilySchema = {
+    val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+      getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeySchema), Seq(0))
+    val ttlValSchema = StructType(
+      Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(ttlKeySchema.take(1)),
+        ttlValSchema,
+        Some(StructType(ttlKeySchema.drop(1)))
+      )
+    )
+  }
+
   def getTimerStateSchema(
       stateName: String,
       keySchema: StructType,
@@ -134,6 +206,29 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       stateName,
       keySchema,
       valSchema,
-      Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)))
+      Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)),
+      avroEnc = getAvroSerde(
+        StructType(keySchema.take(1)),
+        valSchema,
+        Some(StructType(keySchema.drop(1)))
+      ))
   }

+  def getTimerStateSchemaForSecIndex(
+      stateName: String,
+      keySchema: StructType,
+      valSchema: StructType): StateStoreColFamilySchema = {
+    val avroKeySchema = StateStoreColumnFamilySchemaUtils.
+      convertForRangeScan(keySchema, Seq(0))
+    StateStoreColFamilySchema(
+      stateName,
+      keySchema,
+      valSchema,
+      Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(avroKeySchema.take(1)),
+        valSchema,
+        Some(StructType(avroKeySchema.drop(1)))
+      ))
+  }
 }
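Both `getTtlStateSchema` overloads and the timer schemas split the key row the same way: `take(1)` isolates the range-scan prefix handed to `RangeKeyScanStateEncoderSpec`, while `drop(1)` becomes the Avro-encoded suffix. A standalone sketch of that layout, with field names assumed:

```scala
import org.apache.spark.sql.types._

// TTL key row after convertForRangeScan: the long timestamp at ordinal 0 has
// already been rewritten to BinaryType so the store can sort it byte-wise.
val ttlKeySchema = StructType(Seq(
  StructField("expirationMs", BinaryType, nullable = false),
  StructField("groupingKey", BinaryType)))

val rangeScanPrefix = StructType(ttlKeySchema.take(1)) // ordered by the range-scan encoder
val avroSuffix = StructType(ttlKeySchema.drop(1))      // serialized by the suffix-key serde
```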
@@ -27,6 +27,7 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT
+import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, TimeMode, TTLConfig, ValueState}
 import org.apache.spark.util.Utils
@@ -140,7 +141,13 @@

   override def getQueryInfo(): QueryInfo = currQueryInfo

-  private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
+  private lazy val timerStateName = TimerStateUtils.getTimerStateVarName(
+    timeMode.toString)
+  private lazy val timerSecIndexColFamily = TimerStateUtils.getSecIndexColFamilyName(
+    timeMode.toString)
+  private lazy val timerState = new TimerStateImpl(
+    store, timeMode, keyEncoder, schemas(timerStateName).avroEnc,
+    schemas(timerSecIndexColFamily).avroEnc)

   /**
    * Function to register a timer for the given expiryTimestampMs
@@ -382,10 +389,16 @@

   private def addTimerColFamily(): Unit = {
     val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
+    val secIndexColFamilyName = TimerStateUtils.getSecIndexColFamilyName(timeMode.toString)
     val timerEncoder = new TimerKeyEncoder(keyExprEnc)
     val colFamilySchema = schemaUtils.
       getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
+    val secIndexColFamilySchema = schemaUtils.
+      getTimerStateSchemaForSecIndex(secIndexColFamilyName,
+        timerEncoder.keySchemaForSecIndex,
+        timerEncoder.schemaForValueRow)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    columnFamilySchemas.put(secIndexColFamilyName, secIndexColFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName)
     stateVariableInfos.put(stateName, stateVariableInfo)
   }
@@ -404,6 +417,9 @@
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }
@@ -432,6 +448,9 @@
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }
@@ -459,14 +478,19 @@
       ttlConfig: TTLConfig): MapState[K, V] = {
     verifyStateVarOperations("get_map_state", PRE_INIT)

+    val userKeyEnc = encoderFor[K]
+    val valEncoder = encoderFor[V]
     val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) {
       false
     } else {
+      val ttlColFamilyName = getTtlColFamilyName(stateName)
+      val ttlColFamilySchema = schemaUtils.getTtlStateSchema(
+        ttlColFamilyName, keyExprEnc, userKeyEnc.schema)
+      columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema)
       true
     }

-    val userKeyEnc = encoderFor[K]
-    val valEncoder = encoderFor[V]
-
     val colFamilySchema = schemaUtils.
       getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled)
     columnFamilySchemas.put(stateName, colFamilySchema)
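With TTL enabled, the driver-side handle now registers the TTL column family schema up front, under the internal name produced by `getTtlColFamilyName`. A quick usage check:

```scala
import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName

// TTL state lives in an internal sibling column family, prefixed with "$ttl_".
assert(getTtlColFamilyName("countState") == "$ttl_countState")
```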
@@ -21,8 +21,9 @@ import java.time.Duration
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore}
 import org.apache.spark.sql.types._

 object StateTTLSchema {
@@ -79,12 +80,13 @@ abstract class SingleKeyTTLStateImpl(
     stateName: String,
     store: StateStore,
     keyExprEnc: ExpressionEncoder[Any],
-    ttlExpirationMs: Long)
+    ttlExpirationMs: Long,
+    avroEnc: Option[AvroEncoderSpec] = None)
   extends TTLState {

   import org.apache.spark.sql.execution.streaming.StateTTLSchema._

-  private val ttlColumnFamilyName = "$ttl_" + stateName
+  private val ttlColumnFamilyName = getTtlColFamilyName(stateName)
   private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema)
   private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc)
@@ -93,7 +95,7 @@
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))

   store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema, TTL_VALUE_ROW_SCHEMA,
-    RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true)
+    RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true, avroEncoderSpec = avroEnc)

   /**
    * This function will be called when clear() on State Variables
@@ -199,12 +201,13 @@ abstract class CompositeKeyTTLStateImpl[K](
     store: StateStore,
     keyExprEnc: ExpressionEncoder[Any],
     userKeyEncoder: ExpressionEncoder[Any],
-    ttlExpirationMs: Long)
+    ttlExpirationMs: Long,
+    avroEnc: Option[AvroEncoderSpec] = None)
   extends TTLState {

   import org.apache.spark.sql.execution.streaming.StateTTLSchema._

-  private val ttlColumnFamilyName = "$ttl_" + stateName
+  private val ttlColumnFamilyName = getTtlColFamilyName(stateName)
   private val keySchema = getCompositeKeyTTLRowSchema(
     keyExprEnc.schema, userKeyEncoder.schema
   )
@@ -218,7 +221,7 @@

   store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema,
     TTL_VALUE_ROW_SCHEMA, RangeKeyScanStateEncoderSpec(keySchema,
-    Seq(0)), isInternal = true)
+    Seq(0)), isInternal = true, avroEncoderSpec = avroEnc)

   def clearTTLState(): Unit = {
     val iterator = store.iterator(ttlColumnFamilyName)
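Taken together, the TTL base classes now accept the serde that the driver registered under the `$ttl_` name. A hedged sketch of the lookup a caller might perform; the helper name and map shape are assumptions, mirroring the timer-state wiring in StatefulProcessorHandleImpl above:

```scala
import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName
import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, StateStoreColFamilySchema}

// Resolve the primary and TTL serdes for one state variable from the
// driver-provided schema map (column family name -> schema).
def resolveAvroSerdes(
    stateName: String,
    schemas: Map[String, StateStoreColFamilySchema])
  : (Option[AvroEncoderSpec], Option[AvroEncoderSpec]) = {
  (schemas.get(stateName).flatMap(_.avroEnc),
    schemas.get(getTtlColFamilyName(stateName)).flatMap(_.avroEnc))
}
```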