From ef509c8986dbcc9b37387b0bde56c3d71abb7602 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 4 Oct 2017 19:25:22 -0700 Subject: [PATCH 1/9] Partial implementation --- .../FlatMapGroupsWithStateExec.scala | 134 ++----- .../FlatMapGroupsWithState_StateManager.scala | 331 ++++++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 130 +++---- 3 files changed, 425 insertions(+), 170 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 8e82cccbc8fa3..4a765add05871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -60,32 +58,14 @@ case class FlatMapGroupsWithStateExec( ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. 
- private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -125,11 +105,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -143,7 +123,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -158,7 +138,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -167,14 +147,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -183,20 +155,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if 
(isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -205,12 +176,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutPairs = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutPairs.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -220,22 +190,19 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -243,50 +210,24 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -295,28 +236,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala new file mode 100644 index 0000000000000..32e280b131031 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} + + +object FlatMapGroupsWithStateExecHelper { + /** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. + */ + case class StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + + private[FlatMapGroupsWithStateExecHelper] def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } + } + + sealed trait StateManager extends Serializable { + def stateSchema: StructType + def getState(store: StateStore, keyRow: UnsafeRow): StateData + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit + def removeState(store: StateStore, keyRow: UnsafeRow): Unit + def getAllState(store: StateStore): Iterator[StateData] + } + + def createStateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean): StateManager = { + new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) + } + + + private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager { + + protected def getStateObjFromRow: InternalRow => Any + protected def getStateRowFromObj: Any => UnsafeRow + protected def timeoutTimestampOrdinalInRow: Int + + /** Get deserialized state and corresponding timeout timestamp for a key */ + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + override def putState( + store: StateStore, + keyRow: UnsafeRow, + state: Any, + timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateDataForGetAllState = StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + private lazy val stateDataForGets = StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: 
UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinalInRow) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) + } + } + + + private class StateManagerV1(stateEncoder: ExpressionEncoder[Any], shouldStoreTimestamp: Boolean) + extends StateManager { + + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (shouldStoreTimestamp) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = stateEncoder.resolveAndBind().deserializer + + // Index of the additional metadata fields in the state row + private val timeoutTimestampOrdinal = stateAttributes.indexOf(timestampTimeoutAttribute) + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + private lazy val stateDataForGets = StateData() + + override def stateSchema: StructType = stateAttributes.toStructType + + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + override def putState( + store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + // If the state has not yet been set but timeout has been set, then + // we have to generate a row to save the timeout. However, attempting serialize + // null using case class encoder throws - + // java.lang.NullPointerException: Null value appeared in non-nullable field: + // If the schema is inferred from a Scala tuple / case class, or a Java bean, please + // try to use scala.Option[_] or other nullable types. 
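+      // (This limitation is specific to this flat V1 layout, where the encoded state
+      // fields sit inline next to the timeout column, so a row with "no state but a
+      // timeout" cannot be represented; the nested V2 layout below can instead store
+      // a null state alongside a timeout.)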
+ if (state == null && timestamp != NO_TIMESTAMP) { + throw new IllegalStateException( + "Cannot set timeout when state is not defined, that is, state has not been" + + "initialized or has been removed") + } + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateDataForGetAllState = StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } + } + + + /** + * Class to serialize/write/read/deserialize state for + * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. + */ + class StateManagerImplV2(stateEncoder: ExpressionEncoder[Any], shouldStoreTimestamp: Boolean) + extends StateManager { + + /** Schema of the state rows saved in the state store */ + val stateSchema = { + val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema + } + + /** Get deserialized state and corresponding timeout timestamp for a key */ + def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + /** Removed all information related to a key */ + def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + /** Get all the keys and corresponding state rows in the state store */ + def getAllState(store: StateStore): Iterator[StateData] = { + val stateDataForGetAllState = StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + // Ordinals of the information stored in the state row + private lazy val nestedStateOrdinal = 0 + private lazy val timeoutTimestampOrdinal = 1 + + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val nestedStateExpr = CreateNamedStruct( + stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + if (shouldStoreTimestamp) { + Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nestedStateExpr) + } 
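+      // The serialized layout is thus [groupState struct, timeoutTimestamp], with
+      // NO_TIMESTAMP written as the placeholder that putState later overwrites.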
+ } + + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = { + val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) + val deser = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser) + } + + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateSchema.toAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Reusable instance for returning state information + private lazy val stateDataForGets = StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow == null) null + else getStateObjFromRow(stateRow) + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + val row = getStateRowFromObj(obj) + if (obj == null) { + row.setNullAt(nestedStateOrdinal) + } + row + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 988c8e6753e25..f8d138ec129e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -359,13 +359,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -396,7 +396,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -443,6 +443,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( 
+ s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -453,10 +465,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { + state.update(5); state.setTimeoutTimestamp(5000) + }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -475,50 +513,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + expectedState = Some(5), // state 
+      if (priorState == None) {
+        testStateUpdateWithData(
+          s"EventTimeTimeout - $testName - timeout updated without initializing state",
+          stateUpdates = state => {
+            state.setTimeoutTimestamp(10000)
+          },
+          timeoutConf = EventTimeTimeout,
+          priorState = None,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None,
+          expectedTimeoutTimestamp = 10000)
+      }
+
       testStateUpdateWithData(
         s"EventTimeTimeout - $testName - state and timeout timestamp updated",
         stateUpdates =
-          (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) },
+          (state: GroupState[Int]) => {
+            state.update(5); state.setTimeoutTimestamp(5000)
+          },
         timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
@@ -475,50 +513,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
         timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = Some(5),                                 // state should change
-        expectedTimeoutTimestamp = NO_TIMESTAMP)                 // timestamp should not update
-    }
-  }
-
-  // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(),
-  // Try to remove these cases in the future
-  for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
-    val testName =
-      if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout"
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
+        expectedState = Some(5),                                 // state should change
+        expectedTimeoutTimestamp = NO_TIMESTAMP)                 // timestamp should not update
 
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
+      testStateUpdateWithData(
+        s"EventTimeTimeout - $testName - timeout updated after state removed",
+        stateUpdates = state => {
+          state.remove(); state.setTimeoutTimestamp(10000)
+        },
+        timeoutConf = EventTimeTimeout,
+        priorState = priorState,
+        priorTimeoutTimestamp = priorTimeoutTimestamp,
+        expectedState = None,
+        expectedTimeoutTimestamp = 10000)
+    }
   }
 
-  // Tests for StateStoreUpdater.updateStateForTimedOutKeys()
+  // Tests for InputProcessor.processTimedOutState()
   val preTimeoutState = Some(5)
   for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
     testStateUpdateWithTimeout(
@@ -1032,7 +1043,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
       return // there can be no prior timestamp, when there is no prior state
     }
-    test(s"StateStoreUpdater - updates with data - $testName") {
+    test(s"InputProcessor - process new data - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === false, "hasTimedOut not false")
         assert(values.nonEmpty, "Some value is expected")
@@ -1054,7 +1065,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       expectedState: Option[Int],
       expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
 
-    test(s"StateStoreUpdater - updates for timeout - $testName") {
+    test(s"InputProcessor - process timed out state - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === true, "hasTimedOut not true")
         assert(values.isEmpty, "values not empty")
@@ -1081,21 +1092,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     val store = newStateStore()
     val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
       mapGroupsFunc, timeoutConf, currentBatchTimestamp)
-    val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
+    val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store)
+    val stateManager = mapGroupsSparkPlan.stateManager
     val key = intToRow(0)
     // Prepare store with prior state configs
-    if (priorState.nonEmpty) {
-      val row = updater.getStateRow(priorState.get)
-      updater.setTimeoutTimestamp(row, priorTimeoutTimestamp)
-      store.put(key.copy(), row.copy())
+    if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) {
+      stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp)
     }
     // Call updating function to update state store
     def callFunction() = {
       val returnedIter = if (testTimeoutUpdates) {
-        updater.updateStateForTimedOutKeys()
+        inputProcessor.processTimedOutState()
       } else {
-        updater.updateStateForKeysWithData(Iterator(key))
+        inputProcessor.processNewData(Iterator(key))
       }
       returnedIter.size // consume the iterator to force state updates
     }
@@ -1106,15 +1116,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     } else {
       // Call function to update and verify updated state in store
       callFunction()
-      val updatedStateRow = store.get(key)
-      assert(
-        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) ===
expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } From 976a7ea3d5d528e6f1091c696c7f6e865027ee23 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 9 Jul 2018 04:05:10 -0700 Subject: [PATCH 2/9] Fixed and added tests --- .../apache/spark/sql/types/ObjectType.scala | 2 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../FlatMapGroupsWithStateExecHelper.scala | 232 ++++++++++++ .../FlatMapGroupsWithState_StateManager.scala | 331 ------------------ ...latMapGroupsWithStateExecHelperSuite.scala | 219 ++++++++++++ 5 files changed, 453 insertions(+), 333 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..820f40f5bd10b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -43,7 +43,7 @@ case class ObjectType(cls: Class[_]) extends DataType { def asNullable: DataType = this - override def simpleString: String = cls.getName + override def simpleString: String = s"Object[${cls.getName}]" override def acceptsType(other: DataType): Boolean = other match { case ObjectType(otherCls) => cls.isAssignableFrom(otherCls) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 4a765add05871..61640edb54941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -65,7 +65,7 @@ case class FlatMapGroupsWithStateExec( case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } - private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled) + private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled, 2) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala new file mode 100644 index 0000000000000..27f0e35795171 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.ObjectOperator
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.types._
+
+
+object FlatMapGroupsWithStateExecHelper {
+  /**
+   * Class to capture deserialized state and timestamp returned by the state manager.
+   * This is intended for reuse.
+   */
+  case class StateData(
+      var keyRow: UnsafeRow = null,
+      var stateRow: UnsafeRow = null,
+      var stateObj: Any = null,
+      var timeoutTimestamp: Long = -1) {
+
+    private[FlatMapGroupsWithStateExecHelper] def withNew(
+        newKeyRow: UnsafeRow,
+        newStateRow: UnsafeRow,
+        newStateObj: Any,
+        newTimeout: Long): this.type = {
+      keyRow = newKeyRow
+      stateRow = newStateRow
+      stateObj = newStateObj
+      timeoutTimestamp = newTimeout
+      this
+    }
+  }
+
+  sealed trait StateManager extends Serializable {
+    def stateSchema: StructType
+    def getState(store: StateStore, keyRow: UnsafeRow): StateData
+    def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
+    def removeState(store: StateStore, keyRow: UnsafeRow): Unit
+    def getAllState(store: StateStore): Iterator[StateData]
+  }
+
+  def createStateManager(
+      stateEncoder: ExpressionEncoder[Any],
+      shouldStoreTimestamp: Boolean,
+      version: Int): StateManager = {
+    version match {
+      case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
+      case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
+      case _ => throw new IllegalArgumentException(s"Version $version is not supported")
+    }
+  }
+
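+  // For illustration only: for a hypothetical state class `case class RunningCount(count: Long)`
+  // with timestamps enabled, the two formats produce these state-row schemas:
+  //   version 1: STRUCT<count: LONG, timeoutTimestamp: INT>                      (fields inlined)
+  //   version 2: STRUCT<groupState: STRUCT<count: LONG>, timeoutTimestamp: LONG>
+  // Only version 2 can represent "timeout set but no state", via a null groupState struct.
+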
+  // ===============================================================================================
+  // =========================== Private implementations of StateManager ===========================
+  // ===============================================================================================
+
+  private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager {
+
+    protected def stateRowToObject(row: UnsafeRow): Any
+    protected def stateObjectToRow(state: Any): UnsafeRow
+    protected def timeoutTimestampOrdinalInRow: Int
+
+    /** Get deserialized state and corresponding timeout timestamp for a key */
+    override def getState(store: StateStore, keyRow: UnsafeRow): StateData = {
+      val stateRow = store.get(keyRow)
+      stateDataForGets.withNew(keyRow, stateRow, stateRowToObject(stateRow),
+        getTimestamp(stateRow))
+    }
+
+    /** Put state and timeout timestamp for a key */
+    override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = {
+      val stateRow = stateObjectToRow(state)
+      setTimestamp(stateRow, timestamp)
+      store.put(key, stateRow)
+    }
+
+    override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = {
+      store.remove(keyRow)
+    }
+
+    override def getAllState(store: StateStore): Iterator[StateData] = {
+      val stateData = StateData()
+      store.getRange(None, None).map { p =>
+        stateData.withNew(p.key, p.value, stateRowToObject(p.value), getTimestamp(p.value))
+      }
+    }
+
+    private lazy val stateDataForGets = StateData()
+
+    /** Returns the timeout timestamp of a state row, if set */
+    private def getTimestamp(stateRow: UnsafeRow): Long = {
+      if (shouldStoreTimestamp && stateRow != null) {
+        stateRow.getLong(timeoutTimestampOrdinalInRow)
+      } else NO_TIMESTAMP
+    }
+
+    /** Set the timestamp in a state row */
+    private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
+      if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps)
+    }
+  }
+
+
+  private class StateManagerImplV1(
+      stateEncoder: ExpressionEncoder[Any],
+      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
+
+    private val timestampTimeoutAttribute =
+      AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
+
+    private val stateAttributes: Seq[Attribute] = {
+      val encSchemaAttribs = stateEncoder.schema.toAttributes
+      if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
+    }
+
+    override val stateSchema: StructType = stateAttributes.toStructType
+
+    override protected val timeoutTimestampOrdinalInRow: Int = {
+      stateAttributes.indexOf(timestampTimeoutAttribute)
+    }
+
+    private val stateSerializerExprs = {
+      val encoderSerializer = stateEncoder.namedExpressions
+      if (shouldStoreTimestamp) {
+        encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
+      } else {
+        encoderSerializer
+      }
+    }
+
+    private val stateDeserializerExpr = {
+      // Note that this must be done in the driver, as resolving and binding of deserializer
+      // expressions to the encoded type can be safely done only in the driver.
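+      // (Assumption worth noting: the resolved-and-bound expression created here on the
+      // driver is serialized with this Serializable manager and reused on the executors.)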
+ stateEncoder.resolveAndBind().deserializer + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + + override protected def stateRowToObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + override protected def stateObjectToRow(obj: Any): UnsafeRow = { + require(obj != null, "State object cannot be null") + stateSerializerFunc(obj) + } + } + + + private class StateManagerImplV2( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + /** Schema of the state rows saved in the state store */ + override val stateSchema: StructType = { + var schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", LongType, nullable = false) + schema + } + + // Ordinals of the information stored in the state row + private val nestedStateOrdinal = 0 + override protected val timeoutTimestampOrdinalInRow = 1 + + private val stateSerializerExprs = { + val boundRefToSpecificInternalRow = BoundReference( + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + + val nestedStateSerExpr = + CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + + val nullSafeNestedStateSerExpr = { + val nullLiteral = Literal(null, nestedStateSerExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr) + } + + if (shouldStoreTimestamp) { + Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nullSafeNestedStateSerExpr) + } + } + + private val stateDeserializerExpr = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + val boundRefToNestedState = + BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true) + val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deserExpr) + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + + override protected def stateRowToObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + override protected def stateObjectToRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala deleted file mode 100644 index 32e280b131031..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala +++ /dev/null @@ -1,331 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.state - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} -import org.apache.spark.sql.execution.ObjectOperator -import org.apache.spark.sql.execution.streaming.GroupStateImpl -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP -import org.apache.spark.sql.types.{IntegerType, LongType, StructType} - - -object FlatMapGroupsWithStateExecHelper { - /** - * Class to capture deserialized state and timestamp return by the state manager. - * This is intended for reuse. - */ - case class StateData( - var keyRow: UnsafeRow = null, - var stateRow: UnsafeRow = null, - var stateObj: Any = null, - var timeoutTimestamp: Long = -1) { - - private[FlatMapGroupsWithStateExecHelper] def withNew( - newKeyRow: UnsafeRow, - newStateRow: UnsafeRow, - newStateObj: Any, - newTimeout: Long): this.type = { - keyRow = newKeyRow - stateRow = newStateRow - stateObj = newStateObj - timeoutTimestamp = newTimeout - this - } - } - - sealed trait StateManager extends Serializable { - def stateSchema: StructType - def getState(store: StateStore, keyRow: UnsafeRow): StateData - def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit - def removeState(store: StateStore, keyRow: UnsafeRow): Unit - def getAllState(store: StateStore): Iterator[StateData] - } - - def createStateManager( - stateEncoder: ExpressionEncoder[Any], - shouldStoreTimestamp: Boolean): StateManager = { - new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) - } - - - private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager { - - protected def getStateObjFromRow: InternalRow => Any - protected def getStateRowFromObj: Any => UnsafeRow - protected def timeoutTimestampOrdinalInRow: Int - - /** Get deserialized state and corresponding timeout timestamp for a key */ - override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { - val stateRow = store.get(keyRow) - stateDataForGets.withNew( - keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) - } - - /** Put state and timeout timestamp for a key */ - override def putState( - store: StateStore, - keyRow: UnsafeRow, - state: Any, - timestamp: Long): Unit = { - val stateRow = getStateRow(state) - setTimestamp(stateRow, timestamp) - store.put(keyRow, stateRow) - } - - override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { - store.remove(keyRow) - } - - override def getAllState(store: StateStore): Iterator[StateData] = { - val stateDataForGetAllState = StateData() - store.getRange(None, None).map { 
pair => - stateDataForGetAllState.withNew( - pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) - } - } - - private lazy val stateDataForGets = StateData() - - /** Returns the state as Java object if defined */ - private def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - private def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - private def getTimestamp(stateRow: UnsafeRow): Long = { - if (shouldStoreTimestamp && stateRow != null) { - stateRow.getLong(timeoutTimestampOrdinalInRow) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) - } - } - - - private class StateManagerV1(stateEncoder: ExpressionEncoder[Any], shouldStoreTimestamp: Boolean) - extends StateManager { - - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (shouldStoreTimestamp) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - - // Index of the additional metadata fields in the state row - private val timeoutTimestampOrdinal = stateAttributes.indexOf(timestampTimeoutAttribute) - // Converters for translating state between rows and Java objects - private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - private lazy val stateDataForGets = StateData() - - override def stateSchema: StructType = stateAttributes.toStructType - - override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { - val stateRow = store.get(keyRow) - stateDataForGets.withNew( - keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) - } - - override def putState( - store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (state == null && timestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - val stateRow = getStateRow(state) - setTimestamp(stateRow, timestamp) - store.put(keyRow, stateRow) - } - - override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { - store.remove(keyRow) - } - - override def getAllState(store: StateStore): Iterator[StateData] = { - val stateDataForGetAllState = StateData() - store.getRange(None, None).map { pair => - stateDataForGetAllState.withNew( - pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) - } - } - - /** Returns the state as Java object if defined */ - private def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - private def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimestamp(stateRow: UnsafeRow): Long = { - if (shouldStoreTimestamp && stateRow != null) { - stateRow.getLong(timeoutTimestampOrdinal) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) - } - } - - - /** - * Class to serialize/write/read/deserialize state for - * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. - */ - class StateManagerImplV2(stateEncoder: ExpressionEncoder[Any], shouldStoreTimestamp: Boolean) - extends StateManager { - - /** Schema of the state rows saved in the state store */ - val stateSchema = { - val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) - if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema - } - - /** Get deserialized state and corresponding timeout timestamp for a key */ - def getState(store: StateStore, keyRow: UnsafeRow): StateData = { - val stateRow = store.get(keyRow) - stateDataForGets.withNew( - keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) - } - - /** Put state and timeout timestamp for a key */ - def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { - val stateRow = getStateRow(state) - setTimestamp(stateRow, timestamp) - store.put(keyRow, stateRow) - } - - /** Removed all information related to a key */ - def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { - store.remove(keyRow) - } - - /** Get all the keys and corresponding state rows in the state store */ - def getAllState(store: StateStore): Iterator[StateData] = { - val stateDataForGetAllState = StateData() - store.getRange(None, None).map { pair => - stateDataForGetAllState.withNew( - pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) - } - } - - // Ordinals of the information stored in the state row - private lazy val nestedStateOrdinal = 0 - private lazy val timeoutTimestampOrdinal = 1 - - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val nestedStateExpr = CreateNamedStruct( - stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) - if (shouldStoreTimestamp) { - Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) - } else { - Seq(nestedStateExpr) - } 
- } - - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = { - val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) - val deser = stateEncoder.resolveAndBind().deserializer.transformUp { - case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) - } - CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser) - } - - // Converters for translating state between rows and Java objects - private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateSchema.toAttributes) - private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Reusable instance for returning state information - private lazy val stateDataForGets = StateData() - - /** Returns the state as Java object if defined */ - private def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow == null) null - else getStateObjFromRow(stateRow) - } - - /** Returns the row for an updated state */ - private def getStateRow(obj: Any): UnsafeRow = { - val row = getStateRowFromObj(obj) - if (obj == null) { - row.setNullAt(nestedStateOrdinal) - } - row - } - - /** Returns the timeout timestamp of a state row is set */ - private def getTimestamp(stateRow: UnsafeRow): Long = { - if (shouldStoreTimestamp && stateRow != null) { - stateRow.getLong(timeoutTimestampOrdinal) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) - } - } - -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala new file mode 100644 index 0000000000000..6b9b645ae722d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.{Encoder, QueryTest} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + + +class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { + + import testImplicits._ + import FlatMapGroupsWithStateExecHelper._ + + // ============================ StateManagerImplV1 ============================ + + + test(s"StateManager v1 - primitive type - without timestamp") { + val schema = new StructType().add("value", IntegerType, nullable = false) + testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - primitive type - with timestamp") { + val schema = new StructType() + .add("value", IntegerType, nullable = false) + .add("timeoutTimestamp", IntegerType, nullable = false) + testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null + intercept[Exception] { + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, Seq(null)) + } + } + + test(s"StateManager v1 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null + intercept[Exception] { + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, Seq(null)) + } + } + + // ============================ StateManagerImplV2 ============================ + + test(s"StateManager v2 - primitive type - without timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - primitive type - with timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + .add("timeoutTimestamp", LongType, nullable = false) + testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + 
StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues) + } + + test(s"StateManager v2 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues) + } + + + def testStateManagerWithoutTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = false) + assert(stateManager.stateSchema === expectedStateSchema) + testStateManager(stateManager, testValues, NO_TIMESTAMP) + } + + def testStateManagerWithTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = true) + assert(stateManager.stateSchema === expectedStateSchema) + for (timestamp <- Seq(NO_TIMESTAMP, 1000)) { + testStateManager(stateManager, testValues, timestamp) + } + } + + private def testStateManager[T: Encoder]( + stateManager: StateManager, + values: Seq[T], + timestamp: Long): Unit = { + val keys = (1 to values.size).map(_ => newKey()) + val store = new MemoryStateStore() + + // Test stateManager.getState(), putState(), removeState() + keys.zip(values).foreach { case (key, value) => + try { + stateManager.putState(store, key, value, timestamp) + val data = stateManager.getState(store, key) + assert(data.stateObj == value) + assert(data.timeoutTimestamp === timestamp) + stateManager.removeState(store, key) + assert(stateManager.getState(store, key).stateObj == null) + } catch { + case e: Throwable => + fail(s"put/get/remove test with '$value' failed", e) + } + } + + // Test stateManager.getAllState() + for (i <- keys.indices) { + stateManager.putState(store, keys(i), values(i), timestamp) + } + val allData = stateManager.getAllState(store).map(_.copy()).toArray + assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp)) + assert(allData.map(_.stateObj).toSet == values.toSet) + } + + private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { + FlatMapGroupsWithStateExecHelper.createStateManager( + implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + withTimestamp, + version) + } + + private val proj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val keyCounter = new AtomicInteger(0) + private def newKey(): UnsafeRow = { + proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy() + } +} + +case class Struct(d: Double, str: String) +case class NestedStruct(i: Int, nested: Struct) From cfc3f68aabeb4e83bfe8131e93e5f0133fba4869 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 9 Jul 2018 04:19:01 -0700 Subject: [PATCH 3/9] Refactored --- .../FlatMapGroupsWithStateExecHelper.scala | 72 ++++++++----------- 
...latMapGroupsWithStateExecHelperSuite.scala | 5 +- 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index 27f0e35795171..508e29d4b2196 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, Expression, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow} import org.apache.spark.sql.execution.ObjectOperator import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ object FlatMapGroupsWithStateExecHelper { + + val DEFAULT_STATE_MANAGER_VERSION = 2 + /** * Class to capture deserialized state and timestamp return by the state manager. * This is intended for reuse. @@ -75,19 +77,19 @@ object FlatMapGroupsWithStateExecHelper { private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager { - protected def stateRowToObject(row: UnsafeRow): Any - protected def stateObjectToRow(state: Any): UnsafeRow + protected def stateSerializerExprs: Seq[Expression] + protected def stateDeserializerExpr: Expression protected def timeoutTimestampOrdinalInRow: Int /** Get deserialized state and corresponding timeout timestamp for a key */ override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { val stateRow = store.get(keyRow) - stateDataForGets.withNew(keyRow, stateRow, stateRowToObject(stateRow), getTimestamp(stateRow)) + stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow)) } /** Put state and timeout timestamp for a key */ override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = { - val stateRow = stateObjectToRow(state) + val stateRow = getStateRow(state) setTimestamp(stateRow, timestamp) store.put(key, stateRow) } @@ -99,12 +101,24 @@ object FlatMapGroupsWithStateExecHelper { override def getAllState(store: StateStore): Iterator[StateData] = { val stateData = StateData() store.getRange(None, None).map { p => - stateData.withNew(p.key, p.value, stateRowToObject(p.value), getTimestamp(p.value)) + stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value)) } } + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } private lazy val stateDataForGets = StateData() + protected def getStateObject(row: UnsafeRow): Any = { + if (row != null) 
stateDeserializerFunc(row) else null + } + + protected def getStateRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + /** Returns the timeout timestamp of a state row is set */ private def getTimestamp(stateRow: UnsafeRow): Long = { if (shouldStoreTimestamp && stateRow != null) { @@ -133,11 +147,11 @@ object FlatMapGroupsWithStateExecHelper { override val stateSchema: StructType = stateAttributes.toStructType - override protected val timeoutTimestampOrdinalInRow: Int = { + override val timeoutTimestampOrdinalInRow: Int = { stateAttributes.indexOf(timestampTimeoutAttribute) } - private val stateSerializerExprs = { + override val stateSerializerExprs: Seq[Expression] = { val encoderSerializer = stateEncoder.namedExpressions if (shouldStoreTimestamp) { encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) @@ -146,25 +160,11 @@ object FlatMapGroupsWithStateExecHelper { } } - private val stateDeserializerExpr = { - // Note that this must be done in the driver, as resolving and binding of deserializer - // expressions to the encoded type can be safely done only in the driver. - stateEncoder.resolveAndBind().deserializer - } + override val stateDeserializerExpr: Expression = stateEncoder.resolveAndBind().deserializer - private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) - - private lazy val stateDeserializerFunc = { - ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) - } - - override protected def stateRowToObject(row: UnsafeRow): Any = { - if (row != null) stateDeserializerFunc(row) else null - } - - override protected def stateObjectToRow(obj: Any): UnsafeRow = { + override protected def getStateRow(obj: Any): UnsafeRow = { require(obj != null, "State object cannot be null") - stateSerializerFunc(obj) + super.getStateRow(obj) } } @@ -182,9 +182,9 @@ object FlatMapGroupsWithStateExecHelper { // Ordinals of the information stored in the state row private val nestedStateOrdinal = 0 - override protected val timeoutTimestampOrdinalInRow = 1 + override val timeoutTimestampOrdinalInRow = 1 - private val stateSerializerExprs = { + override val stateSerializerExprs: Seq[Expression] = { val boundRefToSpecificInternalRow = BoundReference( 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) @@ -203,7 +203,7 @@ object FlatMapGroupsWithStateExecHelper { } } - private val stateDeserializerExpr = { + override val stateDeserializerExpr: Expression = { // Note that this must be done in the driver, as resolving and binding of deserializer // expressions to the encoded type can be safely done only in the driver. 
val boundRefToNestedState = @@ -213,20 +213,6 @@ object FlatMapGroupsWithStateExecHelper { } CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deserExpr) } - - private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) - - private lazy val stateDeserializerFunc = { - ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) - } - - override protected def stateRowToObject(row: UnsafeRow): Any = { - if (row != null) stateDeserializerFunc(row) else null - } - - override protected def stateObjectToRow(obj: Any): UnsafeRow = { - stateSerializerFunc(obj) - } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala index 6b9b645ae722d..3165cb34fce65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -35,7 +35,6 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { // ============================ StateManagerImplV1 ============================ - test(s"StateManager v1 - primitive type - without timestamp") { val schema = new StructType().add("value", IntegerType, nullable = false) testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) @@ -64,7 +63,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) - // Verify the limitation of v1 with null + // Verify the limitation of v1 with null state intercept[Exception] { testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, Seq(null)) } @@ -87,7 +86,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) - // Verify the limitation of v1 with null + // Verify the limitation of v1 with null state intercept[Exception] { testStateManagerWithTimestamp[NestedStruct](version = 1, schema, Seq(null)) } From 9525484a444ce231ff366bc556fe5a1d46ac4d4f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 9 Jul 2018 10:38:43 -0700 Subject: [PATCH 4/9] Minor refactoring --- .../streaming/state/FlatMapGroupsWithStateExecHelper.scala | 7 +++++-- .../state/FlatMapGroupsWithStateExecHelperSuite.scala | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index 508e29d4b2196..c1cfda19bac0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -160,7 +160,11 @@ object FlatMapGroupsWithStateExecHelper { } } - override val stateDeserializerExpr: Expression = stateEncoder.resolveAndBind().deserializer + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. 
+ stateEncoder.resolveAndBind().deserializer + } override protected def getStateRow(obj: Any): UnsafeRow = { require(obj != null, "State object cannot be null") @@ -215,4 +219,3 @@ object FlatMapGroupsWithStateExecHelper { } } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala index 3165cb34fce65..24599525d715e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -65,7 +65,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { // Verify the limitation of v1 with null state intercept[Exception] { - testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, Seq(null)) + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) } } @@ -88,7 +88,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { // Verify the limitation of v1 with null state intercept[Exception] { - testStateManagerWithTimestamp[NestedStruct](version = 1, schema, Seq(null)) + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) } } From c9f600b0a5a0940b0eb76f9cd4f91ae90a5fe742 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 9 Jul 2018 20:27:24 -0700 Subject: [PATCH 5/9] Added conf and multi-version tests in FlatMapGroupsWithStateSuite --- .../apache/spark/sql/internal/SQLConf.scala | 8 ++++++ .../apache/spark/sql/types/ObjectType.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 5 ++-- .../FlatMapGroupsWithStateExec.scala | 4 ++- .../FlatMapGroupsWithStateExecHelper.scala | 18 +++++++------ ...latMapGroupsWithStateExecHelperSuite.scala | 4 +-- .../FlatMapGroupsWithStateSuite.scala | 25 +++++++++++++------ 7 files changed, 45 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e2c48e2d8a14c..8a12a74a73dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -814,6 +814,14 @@ object SQLConf { .intConf .createWithDefault(10) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") + .internal() + .doc("State format version used by flatMapGroupsWithState operation in a streaming query") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .stringConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 820f40f5bd10b..2d49fe076786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -43,7 +43,7 @@ case class ObjectType(cls: Class[_]) extends DataType { def asNullable: DataType = this - override def simpleString: String = s"Object[${cls.getName}]" + override 
def simpleString: String = cls.getName override def acceptsType(other: DataType): Boolean = other match { case ObjectType(otherCls) => cls.isAssignableFrom(otherCls) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07a6fcae83b70..c2bf40cb22064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -485,9 +485,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, + outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 61640edb54941..bfe7d00f56048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,6 +50,7 @@ case class FlatMapGroupsWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], + stateFormatVersion: Int, outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], @@ -65,7 +66,8 @@ case class FlatMapGroupsWithStateExec( case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } - private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled, 2) + private[sql] val stateManager = + createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index c1cfda19bac0d..e827c41c05413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, Expression, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.ObjectOperator import org.apache.spark.sql.execution.streaming.GroupStateImpl import 
org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ object FlatMapGroupsWithStateExecHelper { - val DEFAULT_STATE_MANAGER_VERSION = 2 + val supportedVersions = Seq(1, 2) /** * Class to capture deserialized state and timestamp return by the state manager. @@ -58,16 +58,17 @@ object FlatMapGroupsWithStateExecHelper { def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit def removeState(store: StateStore, keyRow: UnsafeRow): Unit def getAllState(store: StateStore): Iterator[StateData] + def version: Int } def createStateManager( stateEncoder: ExpressionEncoder[Any], shouldStoreTimestamp: Boolean, - version: Int): StateManager = { - version match { + stateFormatVersion: Int): StateManager = { + stateFormatVersion match { case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp) case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) - case _ => throw new IllegalArgumentException(s"Version $version") + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") } } @@ -75,7 +76,8 @@ object FlatMapGroupsWithStateExecHelper { // =========================== Private implementations of StateManager =========================== // =============================================================================================== - private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager { + private abstract class StateManagerImplBase(val version: Int, shouldStoreTimestamp: Boolean) + extends StateManager { protected def stateSerializerExprs: Seq[Expression] protected def stateDeserializerExpr: Expression @@ -135,7 +137,7 @@ object FlatMapGroupsWithStateExecHelper { private class StateManagerImplV1( stateEncoder: ExpressionEncoder[Any], - shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(1, shouldStoreTimestamp) { private val timestampTimeoutAttribute = AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() @@ -175,7 +177,7 @@ object FlatMapGroupsWithStateExecHelper { private class StateManagerImplV2( stateEncoder: ExpressionEncoder[Any], - shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(2, shouldStoreTimestamp) { /** Schema of the state rows saved in the state store */ override val stateSchema: StructType = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala index 24599525d715e..dec30fd01f7e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.state import java.util.concurrent.atomic.AtomicInteger -import org.apache.spark.sql.{Encoder, QueryTest} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ import org.apache.spark.sql.streaming.StreamTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index f8d138ec129e3..f5a7f3f1563f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -601,7 +602,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("flatMapGroupsWithState - streaming") { + testWithAllStateVersions("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -680,7 +681,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming + aggregation") { + testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -739,7 +740,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("flatMapGroupsWithState - streaming with processing time timeout") { + testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -803,7 +804,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming with event time timeout + watermark") { + testWithAllStateVersions("flatMapGroupsWithState - streaming with event time timeout") { // Function to maintain the max event time as state and set the timeout timestamp based on the // current max event time seen. 
It returns the max event time in the state, or -1 if the state // was removed by timeout. @@ -1135,7 +1136,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, + f, k, v, g, d, o, None, s, 2, m, t, Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } @@ -1168,6 +1169,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } def rowToInt(row: UnsafeRow): Int = row.getInt(0) + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) { + func + } + } + } + } } object FlatMapGroupsWithStateSuite { From 05e3acf5c9ecde48298515652de1e7acda92946a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 11 Jul 2018 00:54:11 -0700 Subject: [PATCH 6/9] Use version 1 when recovering from state --- .../sql/catalyst/expressions/Expression.scala | 3 +- .../sql/execution/streaming/OffsetSeq.scala | 10 ++- .../FlatMapGroupsWithStateExecHelper.scala | 4 +- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 3 + .../offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 84 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 46 bytes .../state/0/1/2.delta | Bin 0 -> 46 bytes .../state/0/2/1.delta | Bin 0 -> 46 bytes .../state/0/2/2.delta | Bin 0 -> 46 bytes .../state/0/3/1.delta | Bin 0 -> 46 bytes .../state/0/3/2.delta | Bin 0 -> 46 bytes .../state/0/4/1.delta | Bin 0 -> 46 bytes .../state/0/4/2.delta | Bin 0 -> 46 bytes .../FlatMapGroupsWithStateSuite.scala | 70 ++++++++++++++++++ 19 files changed, 93 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta create mode 100644 
sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 44c5556ff9ccf..ac9daad45c275 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -725,7 +725,8 @@ trait ComplexTypeMergingExpression extends Expression { "The collection of input data types must not be empty.") require( areInputTypesForMergingEqual, - "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + "All input types must be the same except nullable, containsNull, valueContainsNull flags." + + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 1ae3f36c152cf..9847756f22d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -87,7 +88,8 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( - SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -100,7 +102,9 @@ object OffsetSeqMetadata extends Logging { * with a specific default value for ensuring same behavior of the query as before. 
*/ private val relevantSQLConfDefaultValues = Map[String, String]( - STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> + FlatMapGroupsWithStateExecHelper.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index e827c41c05413..247c7f9047bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.types._ object FlatMapGroupsWithStateExecHelper { val supportedVersions = Seq(1, 2) + val legacyVersion = 1 /** * Class to capture deserialized state and timestamp return by the state manager. @@ -217,7 +218,8 @@ object FlatMapGroupsWithStateExecHelper { val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) } - CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deserExpr) + val nullLiteral = Literal(null, deserExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr) } } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata new file mode 100644 index 0000000000000..372180b2096ee --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..193524ffe15b51c941eb08906f274e7708616f37 GIT binary patch literal 84 zcmeZ?GI7euPtI1=Vqjpf0pdRCZw$deT7rR*VKO6-AppdQ0t_5741)ap3=bF>6#Rf9 QK=2<3e4yGzAwm!m0E+|;=l}o! literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 
icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index f5a7f3f1563f6..8eafe4eb7251b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date import java.util.concurrent.ConcurrentHashMap +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException @@ -36,6 +38,7 @@ import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils /** Class to check custom state types */ case class RunningCount(count: 
Long) @@ -855,6 +858,73 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } + + test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but failing on subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", 11), ("a", 13), ("a", 15)) + inputData.addData(("a", 4)) + + testStream(result, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + */ + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20.
+ CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) From dcf961671c4c9c54da7c7bf968ac880a2acb9ac3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 11 Jul 2018 01:05:46 -0700 Subject: [PATCH 7/9] Revert few unnecessary changes --- .../spark/sql/streaming/FlatMapGroupsWithStateSuite.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 8eafe4eb7251b..8910fd811aa63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -496,9 +496,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { - state.update(5); state.setTimeoutTimestamp(5000) - }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -517,8 +515,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update testStateUpdateWithData( s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", From 3abb5e228a9803346b1dbf6492de976d949e402e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 11 Jul 2018 01:10:04 -0700 Subject: [PATCH 8/9] minor changes --- .../spark/sql/streaming/FlatMapGroupsWithStateSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 8910fd811aa63..5be5c80ab2d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -805,7 +805,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - testWithAllStateVersions("flatMapGroupsWithState - streaming with event time timeout") { + testWithAllStateVersions("flatMapGroupsWithState - streaming w\ event time timeout + watermark") { // Function to maintain the max event time as state and set the timeout timestamp based on the // current max event time seen. It returns the max event time in the state, or -1 if the state // was removed by timeout. 
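The testWithAllStateVersions wrapper applied above reruns each of these tests under both state formats by toggling the internal conf spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion added in PATCH 5 (valid values 1 and 2, default 2). As a rough sketch of how that knob could be exercised outside the test harness, assuming only the conf key and the recovery behavior added in PATCH 6 (the session setup and names below are illustrative, not part of the patch):

import org.apache.spark.sql.SparkSession

// Illustrative session; any active SparkSession behaves the same.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("stateFormatVersionSketch")
  .getOrCreate()

// Pin the legacy v1 row layout before the flatMapGroupsWithState query is planned.
spark.conf.set("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion", "1")

// Once a query starts, the version is recorded with the offset log, so a restart
// from the same checkpoint keeps the version the query started with, and
// checkpoints written before the conf existed fall back to legacy version 1.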
@@ -1197,6 +1197,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int],
       timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout,
       batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = {
+    val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
     MemoryStream[Int]
       .toDS
       .groupByKey(x => x)
@@ -1204,7 +1205,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       .logicalPlan.collectFirst {
         case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
           FlatMapGroupsWithStateExec(
-            f, k, v, g, d, o, None, s, 2, m, t,
+            f, k, v, g, d, o, None, s, stateFormatVersion, m, t,
             Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd"))
       }.get
   }

From c262e87afe8736febcb546827f0af22da14a02d9 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 18 Jul 2018 18:52:48 -0700
Subject: [PATCH 9/9] Added docs and unit tests to verify format

---
 .../FlatMapGroupsWithStateExecHelper.scala    | 34 ++++++++++++++----
 .../FlatMapGroupsWithStateSuite.scala         | 36 +++++++++++++++++--
 2 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
index 247c7f9047bfb..0a16a3819b778 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
@@ -53,13 +53,13 @@ object FlatMapGroupsWithStateExecHelper {
     }
   }
 
+  /** Interface for interacting with state data of FlatMapGroupsWithState */
   sealed trait StateManager extends Serializable {
     def stateSchema: StructType
     def getState(store: StateStore, keyRow: UnsafeRow): StateData
     def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
     def removeState(store: StateStore, keyRow: UnsafeRow): Unit
     def getAllState(store: StateStore): Iterator[StateData]
-    def version: Int
   }
 
   def createStateManager(
@@ -77,7 +77,8 @@ object FlatMapGroupsWithStateExecHelper {
   // =========================== Private implementations of StateManager ===========================
   // ===============================================================================================
 
-  private abstract class StateManagerImplBase(val version: Int, shouldStoreTimestamp: Boolean)
+  /** Common methods for StateManager implementations */
+  private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean)
     extends StateManager {
 
     protected def stateSerializerExprs: Seq[Expression]
@@ -135,10 +136,20 @@ object FlatMapGroupsWithStateExecHelper {
     }
   }
 
-
+  /**
+   * Version 1 of the StateManager which stores the user-defined state as flattened columns in
+   * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The
+   * unsafe rows will look like this.
+   *
+   *    UnsafeRow[ col1 | col2 | col3 | timestamp ]
+   *
+   * The limitation of this format is that timestamp cannot be set when the user-defined
+   * state has been removed. This is because the columns cannot be collectively marked to be
+   * empty/null.
+   */
   private class StateManagerImplV1(
       stateEncoder: ExpressionEncoder[Any],
-      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(1, shouldStoreTimestamp) {
+      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
 
     private val timestampTimeoutAttribute =
       AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
@@ -175,10 +186,21 @@ object FlatMapGroupsWithStateExecHelper {
     }
   }
 
-
+  /**
+   * Version 2 of the StateManager which stores the user-defined state as a nested struct
+   * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The
+   * unsafe rows will look like this.
+   *                 ____________________________
+   *                |                            |
+   *                |                            V
+   *    UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ]
+   *
+   * This allows the entire user-defined state to be collectively marked as empty/null,
+   * thus allowing timestamp to be set without requiring the state to be present.
+   */
   private class StateManagerImplV2(
       stateEncoder: ExpressionEncoder[Any],
-      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(2, shouldStoreTimestamp) {
+      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
 
     /** Schema of the state rows saved in the state store */
     override val stateSchema: StructType = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 5be5c80ab2d8b..82d7755aef5f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -805,7 +805,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     )
   }
 
-  testWithAllStateVersions("flatMapGroupsWithState - streaming w\ event time timeout + watermark") {
+  testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time timeout + watermark") {
     // Function to maintain the max event time as state and set the timeout timestamp based on the
     // current max event time seen. It returns the max event time in the state, or -1 if the state
     // was removed by timeout.
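To make the V1/V2 layout difference concrete, a minimal sketch of the two state-row schemas for the three-field example in the comments above. The nested-column name (`groupState`) and the exact timestamp types are assumptions for illustration, not taken verbatim from the patch:

```scala
import org.apache.spark.sql.types._

object StateFormatSchemasSketch {
  // Hypothetical user-defined state with three fields.
  val userStateSchema: StructType = new StructType()
    .add("col1", IntegerType)
    .add("col2", LongType)
    .add("col3", StringType)

  // Version 1: user fields flattened next to the timeout column. No single
  // cell can be nulled to express "state removed but timestamp retained".
  val v1StateSchema: StructType =
    userStateSchema.add("timeoutTimestamp", IntegerType)

  // Version 2: the whole user state is one nullable struct column, so it can
  // be null while timeoutTimestamp still carries a value.
  val v2StateSchema: StructType = new StructType()
    .add("groupState", userStateSchema, nullable = true)
    .add("timeoutTimestamp", LongType)
}
```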
@@ -856,6 +856,29 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } + test("flatMapGroupsWithState - uses state format version 2 by default") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { // Function to maintain the max event time as state and set the timeout timestamp based on the @@ -899,7 +922,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest inputData.addData(("a", 4)) testStream(result, Update)( - StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")), /* Note: The checkpoint was generated using the following input in Spark version 2.3.1 @@ -916,6 +941,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. // Watermark is still 5 as max event time for all data is still 15. + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1
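Taken together, the two tests above pin down how the new conf is meant to behave: a freshly started query honours the conf, while a query restarted from a checkpoint keeps the format its checkpoint was created with. A hedged usage sketch, assuming an existing `spark` session (not code from the patch):

```scala
import org.apache.spark.sql.internal.SQLConf

// A freshly started flatMapGroupsWithState query writes state in the format
// named by this conf (format 2 is the default added by this patch series).
spark.conf.set(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key, "1")

// A query restarted from an existing checkpoint keeps the format version the
// checkpoint was written with, which is why the Spark 2.3.1 checkpoint above
// recovers with format 1 even though the restarted stream asks for format 2.
```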