Skip to content

Commit a79a120

Browse files
tdaszsxwing
authored andcommitted
[SPARK-20717][SS] Minor tweaks to the MapGroupsWithState behavior
## What changes were proposed in this pull request? Timeout and state data are two independent entities and should be settable independently. Therefore, in the same call of the user-defined function, one should be able to set the timeout before initializing the state and also after removing the state. Whether timeouts can be set or not, should not depend on the current state, and vice versa. However, a limitation of the current implementation is that state cannot be null while timeout is set. This is checked lazily after the function call has completed. ## How was this patch tested? - Updated existing unit tests that test the behavior of GroupState.setTimeout*** wrt to the current state - Added new tests that verify the disallowed cases where state is undefined but timeout is set. Author: Tathagata Das <[email protected]> Closes #17957 from tdas/SPARK-20717. (cherry picked from commit 499ba2c) Signed-off-by: Shixiong Zhu <[email protected]>
1 parent 82ae1f0 commit a79a120

File tree

4 files changed

+139
-55
lines changed

4 files changed

+139
-55
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,20 @@ case class FlatMapGroupsWithStateExec(
230230

231231
// When the iterator is consumed, then write changes to state
232232
def onIteratorCompletion: Unit = {
233+
234+
val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
235+
// If the state has not yet been set but timeout has been set, then
236+
// we have to generate a row to save the timeout. However, attempting serialize
237+
// null using case class encoder throws -
238+
// java.lang.NullPointerException: Null value appeared in non-nullable field:
239+
// If the schema is inferred from a Scala tuple / case class, or a Java bean, please
240+
// try to use scala.Option[_] or other nullable types.
241+
if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
242+
throw new IllegalStateException(
243+
"Cannot set timeout when state is not defined, that is, state has not been" +
244+
"initialized or has been removed")
245+
}
246+
233247
if (keyedState.hasRemoved) {
234248
store.remove(keyRow)
235249
numUpdatedStateRows += 1
@@ -239,7 +253,6 @@ case class FlatMapGroupsWithStateExec(
239253
case Some(row) => getTimeoutTimestamp(row)
240254
case None => NO_TIMESTAMP
241255
}
242-
val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
243256
val stateRowToWrite = if (keyedState.hasUpdated) {
244257
getStateRow(keyedState.get)
245258
} else {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ private[sql] class GroupStateImpl[S](
9191
defined = false
9292
updated = false
9393
removed = true
94-
timeoutTimestamp = NO_TIMESTAMP
9594
}
9695

9796
override def setTimeoutDuration(durationMs: Long): Unit = {
@@ -100,16 +99,10 @@ private[sql] class GroupStateImpl[S](
10099
"Cannot set timeout duration without enabling processing time timeout in " +
101100
"map/flatMapGroupsWithState")
102101
}
103-
if (!defined) {
104-
throw new IllegalStateException(
105-
"Cannot set timeout information without any state value, " +
106-
"state has either not been initialized, or has already been removed")
107-
}
108-
109102
if (durationMs <= 0) {
110103
throw new IllegalArgumentException("Timeout duration must be positive")
111104
}
112-
if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
105+
if (batchProcessingTimeMs != NO_TIMESTAMP) {
113106
timeoutTimestamp = durationMs + batchProcessingTimeMs
114107
} else {
115108
// This is being called in a batch query, hence no processing timestamp.
@@ -135,7 +128,7 @@ private[sql] class GroupStateImpl[S](
135128
s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
136129
s"current watermark ($eventTimeWatermarkMs)")
137130
}
138-
if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
131+
if (batchProcessingTimeMs != NO_TIMESTAMP) {
139132
timeoutTimestamp = timestampMs
140133
} else {
141134
// This is being called in a batch query, hence no processing timestamp.
@@ -213,11 +206,6 @@ private[sql] class GroupStateImpl[S](
213206
"Cannot set timeout timestamp without enabling event time timeout in " +
214207
"map/flatMapGroupsWithState")
215208
}
216-
if (!defined) {
217-
throw new IllegalStateException(
218-
"Cannot set timeout timestamp without any state value, " +
219-
"state has either not been initialized, or has already been removed")
220-
}
221209
}
222210
}
223211

sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ trait GroupState[S] extends LogicalGroupState[S] {
212212
@throws[IllegalArgumentException]("when updating with null")
213213
def update(newState: S): Unit
214214

215-
/** Remove this state. Note that this resets any timeout configuration as well. */
215+
/** Remove this state. */
216216
def remove(): Unit
217217

218218
/**

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 122 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
112112

113113
state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
114114
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
115-
testTimeoutDurationNotAllowed[IllegalStateException](state)
115+
state.setTimeoutDuration(500)
116+
assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
116117
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
117118

118119
state.update(5)
119-
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
120+
assert(state.getTimeoutTimestamp === 1500) // does not change
120121
state.setTimeoutDuration(1000)
121122
assert(state.getTimeoutTimestamp === 2000)
122123
state.setTimeoutDuration("2 second")
123124
assert(state.getTimeoutTimestamp === 3000)
124125
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
125126

126127
state.remove()
127-
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
128-
testTimeoutDurationNotAllowed[IllegalStateException](state)
128+
assert(state.getTimeoutTimestamp === 3000) // does not change
129+
state.setTimeoutDuration(500) // can still be set
130+
assert(state.getTimeoutTimestamp === 1500)
129131
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
130132
}
131133

@@ -134,46 +136,77 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
134136
None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
135137
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
136138
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
137-
testTimeoutTimestampNotAllowed[IllegalStateException](state)
139+
state.setTimeoutTimestamp(5000)
140+
assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state
138141

139142
state.update(5)
143+
assert(state.getTimeoutTimestamp === 5000) // does not change
140144
state.setTimeoutTimestamp(10000)
141145
assert(state.getTimeoutTimestamp === 10000)
142146
state.setTimeoutTimestamp(new Date(20000))
143147
assert(state.getTimeoutTimestamp === 20000)
144148
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
145149

146150
state.remove()
147-
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
151+
assert(state.getTimeoutTimestamp === 20000)
152+
state.setTimeoutTimestamp(5000)
153+
assert(state.getTimeoutTimestamp === 5000) // can be set after removing state
148154
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
149-
testTimeoutTimestampNotAllowed[IllegalStateException](state)
150155
}
151156

152157
test("GroupState - illegal params to setTimeout****") {
153158
var state: GroupStateImpl[Int] = null
154159

155160
// Test setTimeout****() with illegal values
156161
def testIllegalTimeout(body: => Unit): Unit = {
157-
intercept[IllegalArgumentException] { body }
162+
intercept[IllegalArgumentException] {
163+
body
164+
}
158165
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
159166
}
160167

161168
state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
162-
testIllegalTimeout { state.setTimeoutDuration(-1000) }
163-
testIllegalTimeout { state.setTimeoutDuration(0) }
164-
testIllegalTimeout { state.setTimeoutDuration("-2 second") }
165-
testIllegalTimeout { state.setTimeoutDuration("-1 month") }
166-
testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
169+
testIllegalTimeout {
170+
state.setTimeoutDuration(-1000)
171+
}
172+
testIllegalTimeout {
173+
state.setTimeoutDuration(0)
174+
}
175+
testIllegalTimeout {
176+
state.setTimeoutDuration("-2 second")
177+
}
178+
testIllegalTimeout {
179+
state.setTimeoutDuration("-1 month")
180+
}
181+
testIllegalTimeout {
182+
state.setTimeoutDuration("1 month -1 day")
183+
}
167184

168185
state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
169-
testIllegalTimeout { state.setTimeoutTimestamp(-10000) }
170-
testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") }
171-
testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") }
172-
testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") }
173-
testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) }
174-
testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") }
175-
testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") }
176-
testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") }
186+
testIllegalTimeout {
187+
state.setTimeoutTimestamp(-10000)
188+
}
189+
testIllegalTimeout {
190+
state.setTimeoutTimestamp(10000, "-3 second")
191+
}
192+
testIllegalTimeout {
193+
state.setTimeoutTimestamp(10000, "-1 month")
194+
}
195+
testIllegalTimeout {
196+
state.setTimeoutTimestamp(10000, "1 month -1 day")
197+
}
198+
testIllegalTimeout {
199+
state.setTimeoutTimestamp(new Date(-10000))
200+
}
201+
testIllegalTimeout {
202+
state.setTimeoutTimestamp(new Date(-10000), "-3 second")
203+
}
204+
testIllegalTimeout {
205+
state.setTimeoutTimestamp(new Date(-10000), "-1 month")
206+
}
207+
testIllegalTimeout {
208+
state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day")
209+
}
177210
}
178211

179212
test("GroupState - hasTimedOut") {
@@ -318,6 +351,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
318351
}
319352
}
320353

354+
// Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(),
355+
// Try to remove these cases in the future
356+
for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
357+
val testName =
358+
if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout"
359+
testStateUpdateWithData(
360+
s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed",
361+
stateUpdates = state => { state.setTimeoutDuration(5000) },
362+
timeoutConf = ProcessingTimeTimeout,
363+
priorState = None,
364+
priorTimeoutTimestamp = priorTimeoutTimestamp,
365+
expectedException = classOf[IllegalStateException])
366+
367+
testStateUpdateWithData(
368+
s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed",
369+
stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
370+
timeoutConf = ProcessingTimeTimeout,
371+
priorState = Some(5),
372+
priorTimeoutTimestamp = priorTimeoutTimestamp,
373+
expectedException = classOf[IllegalStateException])
374+
375+
testStateUpdateWithData(
376+
s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
377+
stateUpdates = state => { state.setTimeoutTimestamp(10000) },
378+
timeoutConf = EventTimeTimeout,
379+
priorState = None,
380+
priorTimeoutTimestamp = priorTimeoutTimestamp,
381+
expectedException = classOf[IllegalStateException])
382+
383+
testStateUpdateWithData(
384+
s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
385+
stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) },
386+
timeoutConf = EventTimeTimeout,
387+
priorState = Some(5),
388+
priorTimeoutTimestamp = priorTimeoutTimestamp,
389+
expectedException = classOf[IllegalStateException])
390+
}
391+
321392
// Tests for StateStoreUpdater.updateStateForTimedOutKeys()
322393
val preTimeoutState = Some(5)
323394
for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
@@ -806,7 +877,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
806877
priorState: Option[Int],
807878
priorTimeoutTimestamp: Long = NO_TIMESTAMP,
808879
expectedState: Option[Int] = None,
809-
expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
880+
expectedTimeoutTimestamp: Long = NO_TIMESTAMP,
881+
expectedException: Class[_ <: Exception] = null): Unit = {
810882

811883
if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
812884
return // there can be no prior timestamp, when there is no prior state
@@ -820,7 +892,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
820892
}
821893
testStateUpdate(
822894
testTimeoutUpdates = false, mapGroupsFunc, timeoutConf,
823-
priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
895+
priorState, priorTimeoutTimestamp,
896+
expectedState, expectedTimeoutTimestamp, expectedException)
824897
}
825898
}
826899

@@ -839,9 +912,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
839912
stateUpdates(state)
840913
Iterator.empty
841914
}
915+
842916
testStateUpdate(
843917
testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf,
844-
preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
918+
preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp, null)
845919
}
846920
}
847921

@@ -852,7 +926,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
852926
priorState: Option[Int],
853927
priorTimeoutTimestamp: Long,
854928
expectedState: Option[Int],
855-
expectedTimeoutTimestamp: Long): Unit = {
929+
expectedTimeoutTimestamp: Long,
930+
expectedException: Class[_ <: Exception]): Unit = {
856931

857932
val store = newStateStore()
858933
val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
@@ -867,22 +942,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
867942
}
868943

869944
// Call updating function to update state store
870-
val returnedIter = if (testTimeoutUpdates) {
871-
updater.updateStateForTimedOutKeys()
872-
} else {
873-
updater.updateStateForKeysWithData(Iterator(key))
945+
def callFunction() = {
946+
val returnedIter = if (testTimeoutUpdates) {
947+
updater.updateStateForTimedOutKeys()
948+
} else {
949+
updater.updateStateForKeysWithData(Iterator(key))
950+
}
951+
returnedIter.size // consume the iterator to force state updates
874952
}
875-
returnedIter.size // consumer the iterator to force state updates
876-
877-
// Verify updated state in store
878-
val updatedStateRow = store.get(key)
879-
assert(
880-
updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
881-
"final state not as expected")
882-
if (updatedStateRow.nonEmpty) {
953+
if (expectedException != null) {
954+
// Call function and verify the exception type
955+
val e = intercept[Exception] { callFunction() }
956+
assert(e.getClass === expectedException, "Exception thrown but of the wrong type")
957+
} else {
958+
// Call function to update and verify updated state in store
959+
callFunction()
960+
val updatedStateRow = store.get(key)
883961
assert(
884-
updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
885-
"final timeout timestamp not as expected")
962+
updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
963+
"final state not as expected")
964+
if (updatedStateRow.nonEmpty) {
965+
assert(
966+
updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
967+
"final timeout timestamp not as expected")
968+
}
886969
}
887970
}
888971

0 commit comments

Comments
 (0)