Java streaming example (stateful network word count):

@@ -65,7 +65,7 @@ public static void main(String[] args) {
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint(".");

- // Initial RDD input to trackStateByKey
+ // Initial state RDD input to mapWithState
@SuppressWarnings("unchecked")
List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public Tuple2<String, Integer> call(String s) {
});

// Update the cumulative count function
- final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
-     new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+ final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+     new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {

@Override
-   public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+   public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
int sum = one.or(0) + (state.exists() ? state.get() : 0);
Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
state.update(sum);
-     return Optional.of(output);
+     return output;
}
};

- // This will give a Dstream made of state (which is the cumulative count of the words)
- JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
-     wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+ // DStream made of cumulative counts that get updated in every batch
+ JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+     wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));

stateDstream.print();
ssc.start();
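
This example calls ssc.checkpoint(".") before using the stateful API, and the Scala example below does the same: as with updateStateByKey, the state accumulated by mapWithState is checkpointed periodically, so a checkpoint directory must be configured. A minimal Scala sketch of that setup, not part of the diff (the app name, one-second batch interval, and local checkpoint path mirror the examples and are otherwise arbitrary):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
val ssc = new StreamingContext(sparkConf, Seconds(1))

// State used by mapWithState is checkpointed periodically, so a checkpoint
// directory has to be set before the stateful stream runs. "." is fine for
// a local example; use a fault-tolerant path in production.
ssc.checkpoint(".")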

Scala streaming example (StatefulNetworkWordCount.scala):

@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")

- // Initial RDD input to trackStateByKey
+ // Initial state RDD for mapWithState operation
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

// Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
val words = lines.flatMap(_.split(" "))
val wordDstream = words.map(x => (x, 1))

- // Update the cumulative count using updateStateByKey
+ // Update the cumulative count using mapWithState
// This will give a DStream made of state (which is the cumulative count of the words)
- val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+ val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
val output = (word, sum)
state.update(sum)
-   Some(output)
+   output
}

- val stateDstream = wordDstream.trackStateByKey(
-   StateSpec.function(trackStateFunc).initialState(initialRDD))
+ val stateDstream = wordDstream.mapWithState(
+   StateSpec.function(mappingFunc).initialState(initialRDD))
stateDstream.print()
ssc.start()
ssc.awaitTermination()
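
Note the shape change in the hunk above: the trackStateByKey function returned Some(output), while the mapWithState mapping function returns the mapped record directly, and the resulting DStream carries those records. The full key-to-state table is available separately through stateSnapshots(), which the API suite below exercises. A short Scala sketch continuing the example above, not part of the diff:

// stateDstream carries whatever the mapping function returned for each
// record in the current batch -- here, (word, cumulativeCount) pairs.
stateDstream.print()

// stateSnapshots() exposes the state table itself: one (key, state) pair
// per tracked key, independent of which keys appeared in this batch.
val stateSnapshots = stateDstream.stateSnapshots()
stateSnapshots.print()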

Java 8 API suite (tests replicating org.apache.spark.streaming.JavaAPISuite):

@@ -32,12 +32,10 @@
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
- import org.apache.spark.api.java.function.Function2;
- import org.apache.spark.api.java.function.Function4;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
- import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+ import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;

/**
* Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -863,12 +861,12 @@ public void testFlatMapValues() {
/**
* This test is only for testing the APIs. It's not necessary to run it.
*/
- public void testTrackStateByAPI() {
+ public void testMapWithStateAPI() {
JavaPairRDD<String, Boolean> initialRDD = null;
[Review comment from a project member] nit: the method name testTrackStateByAPI should be renamed to testMapWithStateAPI
JavaPairDStream<String, Integer> wordsDstream = null;

- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-     wordsDstream.trackStateByKey(
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+     wordsDstream.mapWithState(
StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
// Use all State's methods here
state.exists();
@@ -884,9 +882,9 @@ StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state)

JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();

- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-     wordsDstream.trackStateByKey(
-         StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+     wordsDstream.mapWithState(
+         StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> {
state.exists();
state.get();
state.isTimingOut();
@@ -898,6 +896,6 @@ StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state)
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));

- JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+ JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots();
}
}
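
As the test exercises, StateSpec is a small builder: besides the mapping function it can carry an initial state RDD, a partitioner, and a timeout after which idle keys are evicted. A Scala sketch of a fully configured spec, not part of the diff (it reuses initialRDD and wordDstream from the Scala example above; the partitioner and timeout values mirror the test):

import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.{Seconds, State, StateSpec}

val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
  val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
  state.update(sum)  // persist the new cumulative count
  (word, sum)        // the mapped record emitted downstream
}

// Everything but the mapping function is optional.
val spec = StateSpec
  .function(mappingFunc)
  .initialState(initialRDD)              // seed the state, e.g. from a prior run
  .partitioner(new HashPartitioner(10))  // control how state is partitioned
  .timeout(Seconds(10))                  // evict keys idle for 10 seconds

val stateDstream = wordDstream.mapWithState(spec)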

streaming/src/main/scala/org/apache/spark/streaming/State.scala (20 changes: 11 additions & 9 deletions):

@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental

/**
* :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in the mapping function used in the
+ * `mapWithState` operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Scala example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
*
* Java example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- *     new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ * // A mapping function that maintains an integer state and returns a String
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *     new Function3<String, Optional<Integer>, State<Integer>, String>() {
*
* @Override
- *   public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ *   public String call(String key, Optional<Integer> value, State<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@
* }
* };
* }}}
+ *
+ * @tparam S Class of the state
*/
@Experimental
sealed abstract class State[S] {
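
The class doc above covers exists, get, update, and remove; the API suite also calls isTimingOut. One subtlety worth sketching: when StateSpec.timeout is set, an idle key should get one final invocation of the mapping function with state.isTimingOut() returning true, and the state can no longer be updated in that call. A hedged Scala sketch of a timeout-aware mapping function, not taken from this diff:

import org.apache.spark.streaming.State

// Handles both regular updates and the final timeout invocation.
def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
  if (state.isTimingOut()) {
    // Final call for an idle key: value is None and the state is about to
    // be removed, so emit a last record instead of calling state.update.
    Some((key, state.get()))
  } else {
    val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)  // persist the new cumulative count
    Some((key, sum))
  }
}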