@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
2222import org .apache .spark .annotation .Experimental
2323import org .apache .spark .api .java .JavaPairRDD
2424import org .apache .spark .rdd .RDD
25+ import org .apache .spark .util .ClosureCleaner
2526import org .apache .spark .{HashPartitioner , Partitioner }
2627
2728
@@ -37,28 +38,33 @@ import org.apache.spark.{HashPartitioner, Partitioner}
3738 *
3839 * Example in Scala:
3940 * {{{
40- * val spec = StateSpec(trackingFunction).numPartitions(10)
41+ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
42+ * ...
43+ * }
44+ *
45+ * val spec = StateSpec.function(trackingFunction).numPartitions(10)
4146 *
4247 * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedType](spec)
4348 * }}}
4449 *
4550 * Example in Java:
4651 * {{{
47- * StateStateSpec[StateType, EmittedDataType] spec =
48- * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10);
52+ * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
53+ * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
54+ * .numPartitions(10);
4955 *
5056 * JavaDStream[EmittedDataType] emittedRecordDStream =
5157 * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
5258 * }}}
5359 */
5460@ Experimental
55- sealed abstract class StateSpec [K , V , S , T ] extends Serializable {
61+ sealed abstract class StateSpec [KeyType , ValueType , StateType , EmittedType ] extends Serializable {
5662
5763 /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/
58- def initialState (rdd : RDD [(K , S )]): this .type
64+ def initialState (rdd : RDD [(KeyType , StateType )]): this .type
5965
6066 /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/
61- def initialState (javaPairRDD : JavaPairRDD [K , S ]): this .type
67+ def initialState (javaPairRDD : JavaPairRDD [KeyType , StateType ]): this .type
6268
6369 /**
6470 * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
@@ -93,15 +99,20 @@ sealed abstract class StateSpec[K, V, S, T] extends Serializable {
9399 *
94100 * Example in Scala:
95101 * {{{
96- * val spec = StateSpec(trackingFunction).numPartitions(10)
102+ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
103+ * ...
104+ * }
105+ *
106+ * val spec = StateSpec.function(trackingFunction).numPartitions(10)
97107 *
98108 * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedType](spec)
99109 * }}}
100110 *
101111 * Example in Java:
102112 * {{{
103- * StateStateSpec[StateType, EmittedDataType] spec =
104- * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10);
113+ * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
114+ * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
115+ * .numPartitions(10);
105116 *
106117 * JavaDStream[EmittedDataType] emittedRecordDStream =
107118 * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
@@ -115,16 +126,17 @@ object StateSpec {
115126 * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream ]] (Scala) or a
116127 * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream ]] (Java).
117128 * @param trackingFunction The function applied on every data item to manage the associated state
118- * and generate the emitted data and
129+ * and generate the emitted data
119130 * @tparam KeyType Class of the keys
120131 * @tparam ValueType Class of the values
121132 * @tparam StateType Class of the states data
122133 * @tparam EmittedType Class of the emitted data
123134 */
124- def apply [KeyType , ValueType , StateType , EmittedType ](
125- trackingFunction : (KeyType , Option [ValueType ], State [StateType ]) => Option [EmittedType ]
135+ def function [KeyType , ValueType , StateType , EmittedType ](
136+ trackingFunction : (Time , KeyType , Option [ValueType ], State [StateType ]) => Option [EmittedType ]
126137 ): StateSpec [KeyType , ValueType , StateType , EmittedType ] = {
127- new StateSpecImpl [KeyType , ValueType , StateType , EmittedType ](trackingFunction)
138+ ClosureCleaner .clean(trackingFunction, checkSerializable = true )
139+ new StateSpecImpl (trackingFunction)
128140 }
129141
130142 /**
@@ -133,24 +145,28 @@ object StateSpec {
133145 * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream ]] (Scala) or a
134146 * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream ]] (Java).
135147 * @param trackingFunction The function applied on every data item to manage the associated state
136- * and generate the emitted data and
137- * @tparam KeyType Class of the keys
148+ * and generate the emitted data
138149 * @tparam ValueType Class of the values
139150 * @tparam StateType Class of the states data
140151 * @tparam EmittedType Class of the emitted data
141152 */
142- def create [KeyType , ValueType , StateType , EmittedType ](
143- trackingFunction : (KeyType , Option [ValueType ], State [StateType ]) => Option [EmittedType ]
144- ): StateSpec [KeyType , ValueType , StateType , EmittedType ] = {
145- apply(trackingFunction)
153+ def function [ValueType , StateType , EmittedType ](
154+ trackingFunction : (Option [ValueType ], State [StateType ]) => EmittedType
155+ ): StateSpec [Any , ValueType , StateType , EmittedType ] = {
156+ ClosureCleaner .clean(trackingFunction, checkSerializable = true )
157+ val wrappedFunction =
158+ (time : Time , key : Any , value : Option [ValueType ], state : State [StateType ]) => {
159+ Some (trackingFunction(value, state))
160+ }
161+ new StateSpecImpl [Any , ValueType , StateType , EmittedType ](wrappedFunction)
146162 }
147163}
148164
149165
150166/** Internal implementation of [[org.apache.spark.streaming.StateSpec ]] interface. */
151167private [streaming]
152168case class StateSpecImpl [K , V , S , T ](
153- function : (K , Option [V ], State [S ]) => Option [T ]) extends StateSpec [K , V , S , T ] {
169+ function : (Time , K , Option [V ], State [S ]) => Option [T ]) extends StateSpec [K , V , S , T ] {
154170
155171 require(function != null )
156172
@@ -186,7 +202,7 @@ case class StateSpecImpl[K, V, S, T](
186202
187203 // ================= Private Methods =================
188204
189- private [streaming] def getFunction (): (K , Option [V ], State [S ]) => Option [T ] = function
205+ private [streaming] def getFunction (): (Time , K , Option [V ], State [S ]) => Option [T ] = function
190206
191207 private [streaming] def getInitialStateRDD (): Option [RDD [(K , S )]] = Option (initialStateRDD)
192208
0 commit comments