@@ -212,9 +212,9 @@ class ListStateImpl[S](
     if (usingAvro) {
       val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
       store.remove(encodedKey, stateName)
-      val entryCount = getEntryCount(encodedKey)
-      TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
-      removeEntryCount(encodedKey)
+      // val entryCount = getEntryCount(encodedKey)
+      // TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
+      // removeEntryCount(encodedKey)
     } else {
       val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
       store.remove(encodedKey, stateName)
@@ -47,16 +47,17 @@ class ListStateImplWithTTL[S](
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
     avroEnc: Option[AvroEncoderSpec], // TODO: Add Avro Encoding support for TTL
+    ttlAvroEnc: Option[AvroEncoderSpec],
     metrics: Map[String, SQLMetric] = Map.empty)
-  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs)
+  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc)
   with ListStateMetricsImpl
   with ListState[S] {
 
   override def stateStore: StateStore = store
   override def baseStateName: String = stateName
   override def exprEncSchema: StructType = keyExprEnc.schema
 
-  private lazy val stateTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder,
+  private lazy val unsafeRowTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder,
     stateName, hasTtl = true)
 
   private lazy val ttlExpirationMs =
@@ -80,19 +81,19 @@ class ListStateImplWithTTL[S](
    * empty iterator is returned.
    */
   override def get(): Iterator[S] = {
-    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
 
     new NextIterator[S] {
 
       override protected def getNext(): S = {
         val iter = unsafeRowValuesIterator.dropWhile { row =>
-          stateTypesEncoder.isExpired(row, batchTimestampMs)
+          unsafeRowTypesEncoder.isExpired(row, batchTimestampMs)
         }
 
         if (iter.hasNext) {
           val currentRow = iter.next()
-          stateTypesEncoder.decodeValue(currentRow)
+          unsafeRowTypesEncoder.decodeValue(currentRow)
         } else {
           finished = true
           null.asInstanceOf[S]
@@ -107,13 +108,13 @@ class ListStateImplWithTTL[S](
   override def put(newState: Array[S]): Unit = {
     validateNewState(newState)
 
-    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
     var isFirst = true
     var entryCount = 0L
     TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
 
     newState.foreach { v =>
-      val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
+      val encodedValue = unsafeRowTypesEncoder.encodeValue(v, ttlExpirationMs)
       if (isFirst) {
         store.put(encodedKey, encodedValue, stateName)
         isFirst = false
@@ -130,10 +131,10 @@ class ListStateImplWithTTL[S](
   /** Append an entry to the list. */
   override def appendValue(newState: S): Unit = {
     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
     val entryCount = getEntryCount(encodedKey)
     store.merge(encodedKey,
-      stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
+      unsafeRowTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
     TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
     upsertTTLForStateKey(encodedKey)
     updateEntryCount(encodedKey, entryCount + 1)
@@ -143,10 +144,10 @@ class ListStateImplWithTTL[S](
   override def appendList(newState: Array[S]): Unit = {
     validateNewState(newState)
 
-    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
     var entryCount = getEntryCount(encodedKey)
     newState.foreach { v =>
-      val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
+      val encodedValue = unsafeRowTypesEncoder.encodeValue(v, ttlExpirationMs)
       store.merge(encodedKey, encodedValue, stateName)
       entryCount += 1
       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
@@ -157,7 +158,7 @@ class ListStateImplWithTTL[S](
 
   /** Remove this state. */
   override def clear(): Unit = {
-    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
     store.remove(encodedKey, stateName)
     val entryCount = getEntryCount(encodedKey)
     TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
@@ -188,7 +189,7 @@ class ListStateImplWithTTL[S](
     var isFirst = true
     var entryCount = 0L
     unsafeRowValuesIterator.foreach { encodedValue =>
-      if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) {
+      if (!unsafeRowTypesEncoder.isExpired(encodedValue, batchTimestampMs)) {
         if (isFirst) {
           store.put(groupingKey, encodedValue, stateName)
           isFirst = false
@@ -205,6 +206,10 @@ class ListStateImplWithTTL[S](
     numValuesExpired
   }
 
+  override def clearIfExpired(groupingKey: Array[Byte]): Long = {
+    0
+  }
+
   private def upsertTTLForStateKey(encodedGroupingKey: UnsafeRow): Unit = {
     upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey)
   }
@@ -221,22 +226,22 @@ class ListStateImplWithTTL[S](
    * end of the micro-batch.
    */
   private[sql] def getWithoutEnforcingTTL(): Iterator[S] = {
-    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey()
     val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
     unsafeRowValuesIterator.map { valueUnsafeRow =>
-      stateTypesEncoder.decodeValue(valueUnsafeRow)
+      unsafeRowTypesEncoder.decodeValue(valueUnsafeRow)
     }
   }
 
   /**
    * Read the ttl value associated with the grouping key.
    */
   private[sql] def getTTLValues(): Iterator[(S, Long)] = {
-    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey()
     val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
     unsafeRowValuesIterator.map { valueUnsafeRow =>
-      (stateTypesEncoder.decodeValue(valueUnsafeRow),
-        stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
+      (unsafeRowTypesEncoder.decodeValue(valueUnsafeRow),
+        unsafeRowTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
     }
   }
 
@@ -245,6 +250,6 @@ class ListStateImplWithTTL[S](
    * grouping key.
    */
   private[sql] def getValuesInTTLState(): Iterator[Long] = {
-    getValuesInTTLState(stateTypesEncoder.encodeGroupingKey())
+    getValuesInTTLState(unsafeRowTypesEncoder.encodeGroupingKey())
   }
 }
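
Note on the get() hunk above (@@ -80,19 +81,19 @@): expired rows are skipped lazily inside the iterator rather than in a separate pass over the store. A minimal stand-alone sketch of that expiry-skipping behavior with plain Scala iterators; Entry, the -1 "no TTL" sentinel, and liveValues are illustrative stand-ins, not the Spark classes:

// Stand-in for a stored row: a value plus its TTL expiration timestamp.
case class Entry(value: String, expirationMs: Long)

// Mirrors the isExpired check; -1 is assumed to mean "no TTL set".
def isExpired(e: Entry, batchTimestampMs: Long): Boolean =
  e.expirationMs != -1 && e.expirationMs <= batchTimestampMs

// Expired rows are dropped as the iterator advances; nothing is buffered.
def liveValues(rows: Iterator[Entry], batchTimestampMs: Long): Iterator[String] =
  rows.collect { case e if !isExpired(e, batchTimestampMs) => e.value }

val rows = Iterator(Entry("a", 500L), Entry("b", 1500L), Entry("c", -1L))
assert(liveValues(rows, batchTimestampMs = 1000L).toList == List("b", "c"))
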
@@ -42,7 +42,7 @@ class MapStateImpl[K, V](
     keyExprEnc: ExpressionEncoder[Any],
     userKeyEnc: Encoder[K],
     valEncoder: Encoder[V],
-    avroSerde: Option[AvroEncoderSpec],
+    avroEnc: Option[AvroEncoderSpec],
     metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging {
 
   // Pack grouping key and user key together as a prefixed composite key
@@ -53,9 +53,9 @@ class MapStateImpl[K, V](
 
   // If we are using Avro, the avroSerde parameter must be populated
   // else, we will default to using UnsafeRow.
-  private val usingAvro: Boolean = avroSerde.isDefined
+  private val usingAvro: Boolean = avroEnc.isDefined
   private val avroTypesEncoder = new CompositeKeyAvroRowEncoder(
-    keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false, avroSerde)
+    keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false, avroEnc)
   private val unsafeRowTypesEncoder = new CompositeKeyUnsafeRowEncoder(
     keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false)
 
@@ -163,4 +163,4 @@ class MapStateImpl[K, V](
       removeKey(key)
     }
   }
-}
+}
@@ -16,21 +16,21 @@
  */
 package org.apache.spark.sql.execution.streaming
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
-import org.apache.spark.sql.execution.streaming.state.AvroEncoderSpec
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema, TTLRangeKeyScanStateEncoderSpec}
+import org.apache.spark.sql.types.{BinaryType, NullType, StructField, StructType}
 
 object StateStoreColumnFamilySchemaUtils {
 
   def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils =
     new StateStoreColumnFamilySchemaUtils(initializeAvroSerde)
 }
 
-class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
+class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Logging {
 
   private def getAvroSerde(
       keySchema: StructType,
@@ -41,6 +41,9 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
       val avroOptions = AvroOptions(Map.empty)
       val keyAvroType = SchemaConverters.toAvroType(keySchema)
       val keySer = new AvroSerializer(keySchema, keyAvroType, nullable = false)
+      val keyDe = new AvroDeserializer(keyAvroType, keySchema,
+        avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+        avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
       val ser = new AvroSerializer(valSchema, avroType, nullable = false)
       val de = new AvroDeserializer(avroType, valSchema,
         avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
@@ -55,12 +58,41 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
           (Some(ckSer), Some(ckDe))
         case None => (None, None)
       }
-      Some(AvroEncoderSpec(keySer, ser, de, ckSerDe._1, ckSerDe._2))
+      Some(AvroEncoderSpec(keySer, keyDe, ser, de, ckSerDe._1, ckSerDe._2))
     } else {
       None
     }
   }
 
+  def getTtlStateSchema[T](
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
+    val ttlKeySchema = new StructType()
+      .add("expirationMs", BinaryType)
+      .add("groupingKey", BinaryType)
+    val ttlValSchema = StructType(Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(TTLRangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(ttlKeySchema, ttlValSchema))
+  }
+
+  def getTtlStateSchema[T](
+      stateName: String,
+      keyEncoder: ExpressionEncoder[Any],
+      userKeyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
+    val ttlKeySchema = getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeyEncoder.schema)
+    val ttlValSchema = StructType(Array(StructField("__dummy__", NullType)))
+    StateStoreColFamilySchema(
+      stateName,
+      ttlKeySchema,
+      ttlValSchema,
+      Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))),
+      avroEnc = getAvroSerde(keyEncoder.schema, ttlValSchema))
+  }
+
   def getValueStateSchema[T](
       stateName: String,
       keyEncoder: ExpressionEncoder[Any],
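
The two getTtlStateSchema overloads added above describe a TTL secondary index: the key is (expirationMs, groupingKey), the value is a dummy NullType column, and the encoder spec range-scans on field ordinal 0. A toy model of why that layout works, assuming the fixed-width big-endian timestamp encoding that the encodeTTLRow change later in this PR suggests: byte-ordered keys with a timestamp prefix turn "everything expiring before t" into a single range scan. All names below are illustrative, not the Spark API:

import java.nio.ByteBuffer
import java.util.{Arrays, TreeMap}

// Byte-ordered map standing in for a byte-ordered state store backend.
val byteOrder: java.util.Comparator[Array[Byte]] =
  (a: Array[Byte], b: Array[Byte]) => Arrays.compareUnsigned(a, b)
val ttlIndex = new TreeMap[Array[Byte], String](byteOrder)

// Big-endian timestamp prefix, grouping key suffix: expiration order dominates.
def ttlKey(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] =
  ByteBuffer.allocate(8 + groupingKey.length)
    .putLong(expirationMs).put(groupingKey).array()

ttlIndex.put(ttlKey(2000L, "b".getBytes("UTF-8")), "b")
ttlIndex.put(ttlKey(1000L, "a".getBytes("UTF-8")), "a")
ttlIndex.put(ttlKey(3000L, "c".getBytes("UTF-8")), "c")

// "Everything expired by t = 2500" becomes one prefix range scan.
val expired = ttlIndex.headMap(ttlKey(2500L, Array.emptyByteArray), true)
assert(expired.size == 2) // the t=1000 and t=2000 entries
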
@@ -18,10 +18,12 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
 
 import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter}
 import org.apache.avro.io.{DecoderFactory, EncoderFactory}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.InternalRow
@@ -49,6 +51,11 @@ object TransformWithStateKeyValueRowSchemaUtils {
       .add("expirationMs", LongType)
       .add("groupingKey", keySchema)
 
+  def getSingleKeyTTLAvroRowSchema: StructType =
+    new StructType()
+      .add("expirationMs", BinaryType)
+      .add("groupingKey", BinaryType)
+
   def getCompositeKeyTTLRowSchema(
       groupingKeySchema: StructType,
       userKeySchema: StructType): StructType =
@@ -188,7 +195,7 @@ class AvroTypesEncoder[V](
     valEncoder: Encoder[V],
     stateName: String,
     hasTtl: Boolean,
-    avroSerde: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] {
+    avroSerde: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] with Logging {
 
   val out = new ByteArrayOutputStream
 
@@ -202,7 +209,7 @@ class AvroTypesEncoder[V](
   private val keyAvroType = SchemaConverters.toAvroType(keySchema)
 
   // case class -> dataType
-  private val valSchema: StructType = valEncoder.schema
+  private val valSchema: StructType = getValueSchemaWithTTL(valEncoder.schema, hasTtl)
   // dataType -> avroType
   private val valueAvroType = SchemaConverters.toAvroType(valSchema)
 
@@ -251,15 +258,38 @@ class AvroTypesEncoder[V](
   }
 
   override def encodeValue(value: V, expirationMs: Long): Array[Byte] = {
-    throw new UnsupportedOperationException
+    val objRow: InternalRow = objToRowSerializer.apply(value).copy() // V -> InternalRow
+    val avroData =
+      avroSerde.get.valueSerializer.serialize(
+        InternalRow(objRow, expirationMs)) // InternalRow -> GenericDataRecord
+    out.reset()
+
+    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
+    val writer = new GenericDatumWriter[Any](
+      valueAvroType) // Defining Avro writer for this struct type
+
+    writer.write(avroData, encoder) // GenericDataRecord -> bytes
+    encoder.flush()
+    out.toByteArray
   }
 
   override def decodeTtlExpirationMs(row: Array[Byte]): Option[Long] = {
-    throw new UnsupportedOperationException
+    val reader = new GenericDatumReader[Any](valueAvroType)
+    val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null)
+    val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord
+    val internalRow = avroSerde.get.valueDeserializer.deserialize(
+      genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow
+    val expirationMs = internalRow.getLong(1)
+    if (expirationMs == -1) {
+      None
+    } else {
+      Some(expirationMs)
+    }
   }
 
   override def isExpired(row: Array[Byte], batchTimestampMs: Long): Boolean = {
-    throw new UnsupportedOperationException
+    val expirationMs = decodeTtlExpirationMs(row)
+    expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs))
   }
 }

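The three methods filled in above are a standard Avro binary round-trip: InternalRow to GenericRecord via the serde, GenericDatumWriter plus directBinaryEncoder down to bytes, and the reverse through binaryDecoder and GenericDatumReader, with -1 apparently serving as the "no TTL" sentinel. A self-contained sketch of the same round-trip against the plain Avro Java API; the record name and fields are invented for illustration:

import java.io.ByteArrayOutputStream

import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

// Invented record shape mirroring the (value, expirationMs) layout above.
val schema = SchemaBuilder.record("ValueWithTtl").fields()
  .requiredString("value")
  .requiredLong("expirationMs")
  .endRecord()

// GenericRecord -> bytes, as encodeValue does.
val record = new GenericData.Record(schema)
record.put("value", "hello")
record.put("expirationMs", 1234L)

val out = new ByteArrayOutputStream()
val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
new GenericDatumWriter[GenericRecord](schema).write(record, encoder)
encoder.flush()
val bytes = out.toByteArray

// bytes -> GenericRecord, as decodeTtlExpirationMs does.
val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
val decoded = new GenericDatumReader[GenericRecord](schema).read(null, decoder)
assert(decoded.get("expirationMs") == 1234L) // longs decode as-is; strings come back as Utf8
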
@@ -436,15 +466,34 @@ class CompositeKeyAvroRowEncoder[K, V](
 
 /** Class for TTL with single key serialization */
 class SingleKeyTTLEncoder(
-    keyExprEnc: ExpressionEncoder[Any]) {
+    keyExprEnc: ExpressionEncoder[Any],
+    avroEnc: Option[AvroEncoderSpec] = None) extends Logging {
 
+  private lazy val out = new ByteArrayOutputStream
   private val ttlKeyProjection = UnsafeProjection.create(
     getSingleKeyTTLRowSchema(keyExprEnc.schema))
 
+  private val ttlKeyAvroType = SchemaConverters.toAvroType(
+    getSingleKeyTTLAvroRowSchema)
+
   def encodeTTLRow(expirationMs: Long, groupingKey: UnsafeRow): UnsafeRow = {
     ttlKeyProjection.apply(
       InternalRow(expirationMs, groupingKey.asInstanceOf[InternalRow]))
   }
 
+  def encodeTTLRow(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] = {
+    val expMsBytes = ByteBuffer.allocate(8).putLong(expirationMs).array()
+    val internalRow = InternalRow(expMsBytes, groupingKey)
+    val avroData = avroEnc.get.keySerializer.serialize(internalRow) // InternalRow -> Avro Record
+    out.reset()
+    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
+    val writer = new GenericDatumWriter[Any](
+      ttlKeyAvroType) // Defining Avro writer for this struct type
+
+    writer.write(avroData, encoder) // GenericDataRecord -> bytes
+    encoder.flush()
+    out.toByteArray
+  }
 }
 
 /** Class for TTL with composite key serialization */
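
The byte-array encodeTTLRow overload above fixes the expiration timestamp as 8 big-endian bytes (ByteBuffer's default byte order) before handing it to Avro. A quick standalone check of the property this presumably relies on: for non-negative longs such as epoch timestamps, unsigned lexicographic byte order matches numeric order, so byte-sorted TTL keys scan in expiration order (negative longs would need a sign-bit flip, which timestamps avoid):

import java.nio.ByteBuffer
import java.util.Arrays

// 8 big-endian bytes, exactly what ByteBuffer.allocate(8).putLong produces.
def encodeBE(v: Long): Array[Byte] = ByteBuffer.allocate(8).putLong(v).array()

val ts = Seq(0L, 1L, 255L, 256L, 1000000L, Long.MaxValue)
val resorted = ts.reverse.map(encodeBE)
  .sortWith((a, b) => Arrays.compareUnsigned(a, b) < 0) // unsigned lexicographic byte order
  .map(b => ByteBuffer.wrap(b).getLong)
assert(resorted == ts) // byte order reproduces numeric order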