Skip to content

Commit a1d3f75

Browse files
committed
Refactored to address PR comments
1 parent 0c5fe55 commit a1d3f75

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -138,18 +138,16 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
138138
// If the RDD is not partitioned the right way, let us repartition it using the
139139
// partition index as the key. This is to ensure that state RDD is always partitioned
140140
// before creating another state RDD using it
141-
val kvRDD = rdd.mapPartitions { iter =>
142-
iter.map { x => (TaskContext.get().partitionId(), x)}
143-
}
144-
kvRDD.partitionBy(partitioner).mapPartitions(iter => iter.map { _._2 },
145-
preservesPartitioning = true)
141+
TrackStateRDD.createFromRDD[K, V, S, E](
142+
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
146143
} else {
147144
rdd
148145
}
149146
case None =>
150147
TrackStateRDD.createFromPairRDD[K, V, S, E](
151148
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
152-
partitioner, validTime
149+
partitioner,
150+
validTime
153151
)
154152
}
155153

streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
179179

180180
private[streaming] object TrackStateRDD {
181181

182-
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
182+
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
183183
pairRDD: RDD[(K, S)],
184184
partitioner: Partitioner,
185-
updateTime: Time): TrackStateRDD[K, V, S, T] = {
185+
updateTime: Time): TrackStateRDD[K, V, S, E] = {
186186

187187
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
188188
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
189189
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
190-
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
190+
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
191191
}, preservesPartitioning = true)
192192

193193
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
194194

195195
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
196196

197-
new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
197+
new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
198+
}
199+
200+
def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
201+
rdd: RDD[(K, S, Long)],
202+
partitioner: Partitioner,
203+
updateTime: Time): TrackStateRDD[K, V, S, E] = {
204+
205+
val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
206+
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
207+
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
208+
iterator.foreach { case (key, (state, updateTime)) =>
209+
stateMap.put(key, state, updateTime)
210+
}
211+
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
212+
}, preservesPartitioning = true)
213+
214+
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
215+
216+
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
217+
218+
new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
198219
}
199220
}
200221

streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,7 @@ class TrackStateByKeySuite extends SparkFunSuite
478478
}
479479

480480

481-
test("trackStateByKey - drivery failure recovery") {
481+
test("trackStateByKey - driver failure recovery") {
482482
val inputData =
483483
Seq(
484484
Seq(),

0 commit comments

Comments
 (0)