
Commit 6a6c1fc

BryanCutler authored and davies committed
[SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark
Adding the ability to define an initial state RDD for use with updateStateByKey in PySpark. Added a unit test and changed the stateful_network_wordcount example to use an initial RDD.

Author: Bryan Cutler <[email protected]>

Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713.
1 parent 4a46b88 · commit 6a6c1fc

File tree

4 files changed · +47 −5 lines
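Before the per-file diffs, a minimal end-to-end sketch of the feature this commit adds: seeding updateStateByKey with an initial state RDD. This assumes a local socket source; the app name, host, and port are placeholders, not part of the commit:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="InitialRDDSketch")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("checkpoint")  # updateStateByKey requires checkpointing

    def update(new_values, last_sum):
        # Fold this batch's counts into the running total for the key.
        return sum(new_values) + (last_sum or 0)

    pairs = ssc.socketTextStream("localhost", 9999) \
        .flatMap(lambda line: line.split(" ")) \
        .map(lambda word: (word, 1))

    # Seed the state so these keys start at 1 before any batch arrives.
    initial = sc.parallelize([(u'hello', 1), (u'world', 1)])
    counts = pairs.updateStateByKey(update, initialRDD=initial)
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()

With the seed in place, the very first batch folds into the initial counts instead of starting every key from scratch.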

examples/src/main/python/streaming/stateful_network_wordcount.py

Lines changed: 4 additions & 1 deletion
@@ -44,13 +44,16 @@
     ssc = StreamingContext(sc, 1)
     ssc.checkpoint("checkpoint")
 
+    # RDD with initial state (key, value) pairs
+    initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)])
+
     def updateFunc(new_values, last_sum):
         return sum(new_values) + (last_sum or 0)
 
     lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
     running_counts = lines.flatMap(lambda line: line.split(" "))\
         .map(lambda word: (word, 1))\
-        .updateStateByKey(updateFunc)
+        .updateStateByKey(updateFunc, initialRDD=initialStateRDD)
 
     running_counts.pprint()
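To see what the seed buys the example, here is a plain-Python sketch (no Spark) of how a first batch merges with the initial state; the sample batch is made up:

    def updateFunc(new_values, last_sum):
        return sum(new_values) + (last_sum or 0)

    state = {u'hello': 1, u'world': 1}        # as seeded by initialStateRDD
    batch = [(u'hello', 1), (u'spark', 1)]    # a hypothetical first batch

    for word, count in batch:
        state[word] = updateFunc([count], state.get(word))

    print(state)   # {'hello': 2, 'world': 1, 'spark': 1}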

python/pyspark/streaming/dstream.py

Lines changed: 11 additions & 2 deletions
@@ -568,7 +568,7 @@ def invReduceFunc(t, a, b):
                                             self._ssc._jduration(slideDuration))
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
 
-    def updateStateByKey(self, updateFunc, numPartitions=None):
+    def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None):
         """
         Return a new "state" DStream where the state for each key is updated by applying
         the given function on the previous state of the key and the new values of the key.
@@ -579,6 +579,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None):
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
 
+        if initialRDD and not isinstance(initialRDD, RDD):
+            initialRDD = self._sc.parallelize(initialRDD)
+
         def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
@@ -590,7 +593,13 @@ def reduceFunc(t, a, b):
 
         jreduceFunc = TransformFunction(self._sc, reduceFunc,
                                         self._sc.serializer, self._jrdd_deserializer)
-        dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        if initialRDD:
+            initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc,
+                                                       initialRDD._jrdd)
+        else:
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
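Note that initialRDD accepts either an RDD of (key, value) pairs or a plain iterable, which the isinstance check above turns into an RDD on the caller's behalf; the _reserialize call then aligns the RDD's serializer with the DStream's before its _jrdd handle is passed to the JVM constructor. A standalone sketch of the coercion step, assuming a local SparkContext:

    from pyspark import SparkContext
    from pyspark.rdd import RDD

    sc = SparkContext("local[2]", "CoercionSketch")

    initialRDD = [(u'hello', 1), (u'world', 1)]   # a plain list, not an RDD
    if initialRDD and not isinstance(initialRDD, RDD):
        initialRDD = sc.parallelize(initialRDD)   # mirrors the check above

    print(isinstance(initialRDD, RDD))       # True
    print(sorted(initialRDD.collect()))      # [('hello', 1), ('world', 1)]
    sc.stop()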

python/pyspark/streaming/tests.py

Lines changed: 20 additions & 0 deletions
@@ -403,6 +403,26 @@ def func(dstream):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)
 
+    def test_update_state_by_key_initial_rdd(self):
+
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
+
+        initial = [('k', [0, 1])]
+        initial = self.sc.parallelize(initial, 1)
+
+        input = [[('k', i)] for i in range(2, 5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater, initialRDD=initial)
+
+        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
     def test_failed_func(self):
         # Test failure in
         # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
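As a quick sanity check on the expected values, the updater above can be exercised in plain Python (no Spark), seeding the state the way the initial RDD does:

    def updater(vs, s):
        if not s:
            s = []
        s.extend(vs)
        return s

    state = {'k': [0, 1]}                     # stands in for the initial RDD
    for batch in ([('k', 2)], [('k', 3)], [('k', 4)]):
        for k, v in batch:
            state[k] = updater([v], state.get(k))
        print(state['k'])
    # [0, 1, 2]
    # [0, 1, 2, 3]
    # [0, 1, 2, 3, 4]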

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 12 additions & 2 deletions
@@ -264,17 +264,27 @@ private[python] class PythonTransformed2DStream(
  */
 private[python] class PythonStateDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonTransformFunction)
+    reduceFunc: PythonTransformFunction,
+    initialRDD: Option[RDD[Array[Byte]]])
   extends PythonDStream(parent, reduceFunc) {
 
+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None)
+
+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction,
+      initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd))
+
   super.persist(StorageLevel.MEMORY_ONLY)
   override val mustCheckpoint = true
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     val lastState = getOrCompute(validTime - slideDuration)
     val rdd = parent.getOrCompute(validTime)
     if (rdd.isDefined) {
-      func(lastState, rdd, validTime)
+      func(lastState.orElse(initialRDD), rdd, validTime)
     } else {
       lastState
     }
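The key behavioral change is lastState.orElse(initialRDD): the seed is consulted only until a computed state exists. A minimal Python sketch of that fallback (the numeric state here is illustrative only, not the byte-array RDDs the Scala code handles):

    def compute(last_state, initial_state, batch_values):
        # Mirror lastState.orElse(initialRDD): fall back to the seed only
        # when no previously computed state is available.
        base = last_state if last_state is not None else initial_state
        return base + sum(batch_values)

    state = None
    for batch in ([1, 1], [1], []):
        state = compute(state, 10, batch)
        print(state)   # 12, then 13, then 13 -- the seed applies only once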
