From 7f15e29a4c2500cadbd9fd38a7aca5e58e2b9487 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 8 Oct 2015 01:22:22 -0700 Subject: [PATCH 01/26] First draft of sessionByKey --- .../streaming/StatefulNetworkWordCount.scala | 14 +- .../dstream/PairDStreamFunctions.scala | 10 +- .../streaming/dstream/SessionDStream.scala | 359 ++++++++++++++++++ .../streaming/BasicOperationsSuite.scala | 71 +++- .../spark/streaming/SessionMapSuite.scala | 82 ++++ 5 files changed, 526 insertions(+), 10 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f..944ad2be159b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.SessionSpec /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -44,6 +44,7 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() + /* val updateFunc = (values: Seq[Int], state: Option[Int]) => { val currentCount = values.sum @@ -54,7 +55,7 @@ object StatefulNetworkWordCount { val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } + }*/ val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size @@ -72,8 +73,13 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + + val updateFunc = (value: Int, sessionData: Option[Int]) => { + Option(value + sessionData.getOrElse(0)) + } + + val stateDstream = wordDstream.sessionByKey[Int]( + SessionSpec.create(updateFunc).reportAllSession(true)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8..c29593465cf7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -35,8 +35,7 @@ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -350,6 +349,13 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + def sessionByKey[S: ClassTag](sessionSpec: SessionSpec[K, V, S]): DStream[Session[K, S]] 
= { + new SessionDStream[K, V, S](self, sessionSpec).mapPartitions { partitionIter => + partitionIter.flatMap { _.iterator(!sessionSpec.getAllSessions()) } + } + } + + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala new file mode 100644 index 000000000000..077179efa0a3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -0,0 +1,359 @@ +package org.apache.spark.streaming.dstream + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Time} +import org.apache.spark.util.Utils + + +// ================================================== +// ================================================== +// ================= PUBLIC CLASSES ================= +// ================================================== +// ================================================== + + +/** Represents a session */ +case class Session[K, S] private[streaming]( + private var key: K, private var data: S, private var active: Boolean) { + + def this() = this(null.asInstanceOf[K], null.asInstanceOf[S], true) + + private[streaming] def set(k: K, s: S, a: Boolean): this.type = { + key = k + data = s + active = a + this + } + + /** Get the session key */ + def getKey(): K = key + + /** Get the session value */ + def getData(): S = data + + /** Whether the session is active */ + def isActive(): Boolean = active +/* + override def toString(): String = { + s"Session[ Key=$key, Data=$session, active=$active ]" + }*/ +} + +private[streaming] object Session { + +} + +/** Class representing all the specification of session */ +class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]() extends Serializable { + @volatile private var updateFunction: (V, Option[S]) => Option[S] = null + @volatile private var partitioner: Partitioner = null + @volatile private var initialSessionRDD: RDD[(K, S)] = null + @volatile private var allSessions: Boolean = false + + def setPartition(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + def setUpdateFunction(func: (V, Option[S]) => Option[S]): this.type = { + updateFunction = func + this + } + + def setInitialSessions(initialRDD: RDD[(K, S)]): this.type = { + this.initialSessionRDD = initialRDD + this + } + + def reportAllSession(allSessions: Boolean): this.type = { + this.allSessions = allSessions + this + } + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getUpdateFunction(): (V, Option[S]) => Option[S] = updateFunction + + private[streaming] def getInitialSessions(): Option[RDD[(K, S)]] = Option(initialSessionRDD) + + private[streaming] def getAllSessions(): Boolean = allSessions + + private[streaming] def validate(): Unit = { + require(updateFunction != null) + } +} + +object SessionSpec { + def create[K: ClassTag, V: ClassTag, S: ClassTag]( + updateFunction: (V, Option[S]) => Option[S]): SessionSpec[K, V, S] = { + new SessionSpec[K, V, S].setUpdateFunction(updateFunction) + } +} + + + +// 
=============================================== +// =============================================== +// ============== PRIVATE CLASSES ================ +// =============================================== +// =============================================== + + + +// ----------------------------------------------- +// --------------- SessionMap stuff -------------- +// ----------------------------------------------- + +/** + * Internal interface for defining the map that keeps track of sessions. + */ +private[streaming] abstract class SessionMap[K: ClassTag, S: ClassTag] extends Serializable { + /** Add or update session data */ + + def put(key: K, session: S): Unit + + /** Get the session data if it exists */ + def get(key: K): Option[S] + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy the map to create a new session map. Updates to the new map + * should not mutate `this` map. + */ + def copy(): SessionMap[K, S] + + /** + * Return an iterator of data in this map. If th flag is true, implementations should + * return only the session that were updated since the creation of this map. + */ + def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] +} + +private[streaming] object SessionMap { + def empty[K: ClassTag, S: ClassTag]: SessionMap[K, S] = new EmptySessionMap[K, S] + + def create[K: ClassTag, S: ClassTag](): SessionMap[K, S] = new HashMapBasedSessionMap[K, S]() +} + +/** Specific implementation of SessionMap interface representing an empty map */ +private[streaming] class EmptySessionMap[K: ClassTag, S: ClassTag] extends SessionMap[K, S] { + override def put(key: K, session: S): Unit = ??? + override def get(key: K): Option[S] = None + override def copy(): SessionMap[K, S] = new EmptySessionMap[K, S] + override def remove(key: K): Unit = { } + override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = Iterator.empty +} + + +/** Specific implementation of the SessionMap interface using a scala mutable HashMap */ +private[streaming] class HashMapBasedSessionMap[K: ClassTag, S: ClassTag]( + parentSessionMap: SessionMap[K, S]) extends SessionMap[K, S] { + + def this() = this(new EmptySessionMap[K, S]) + + import HashMapBasedSessionMap._ + + private val generation: Int = parentSessionMap match { + case map: HashMapBasedSessionMap[_, _] => map.generation + 1 + case _ => 1 + } + + private val internalMap = new mutable.HashMap[K, SessionInfo[S]] + + override def put(key: K, session: S): Unit = { + internalMap.get(key) match { + case Some(sessionInfo) => + sessionInfo.data = session + case None => + internalMap.put(key, new SessionInfo(session)) + } + } + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + internalMap.get(key).filter { _.deleted == false }.map { _.data }.orElse(parentSessionMap.get(key)) + } + + /** Remove a key */ + override def remove(key: K): Unit = { + internalMap.put(key, new SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) + } + + /** + * Return an iterator of data in this map. If th flag is true, implementations should + * return only the session that were updated since the creation of this map. 
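+   * In this implementation, keys updated through this map are served from the internal
+   * delta map, while untouched keys fall back to the parent map (keys overridden in the
+   * delta are filtered out of the parent's sessions so nothing is emitted twice).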
+ */ + override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = { + val updatedSessions = internalMap.iterator.map { case (key, sessionInfo) => + Session(key, sessionInfo.data, !sessionInfo.deleted) + } + + def previousSessions = parentSessionMap.iterator(updatedSessionsOnly = false).filter { session => + !internalMap.contains(session.getKey()) + } + + if (updatedSessionsOnly) { + updatedSessions + } else { + previousSessions ++ updatedSessions + } + } + + /** + * Shallow copy the map to create a new session map. Updates to the new map + * should not mutate `this` map. + */ + override def copy(): SessionMap[K, S] = { + doCopy(generation >= HashMapBasedSessionMap.GENERATION_THRESHOLD_FOR_CONSOLIDATION) + } + + def doCopy(consolidate: Boolean): SessionMap[K, S] = { + if (consolidate) { + val newParentMap = new HashMapBasedSessionMap[K, S]() + iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => + newParentMap.internalMap.put(session.getKey(), SessionInfo(session.getData(), deleted = false)) + } + new HashMapBasedSessionMap[K, S](newParentMap) + } else { + new HashMapBasedSessionMap[K, S](this) + } + } +} + +private[streaming] object HashMapBasedSessionMap { + + case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) + + val GENERATION_THRESHOLD_FOR_CONSOLIDATION = 10 +} + + +// ----------------------------------------------- +// --------------- SessionRDD stuff -------------- +// ----------------------------------------------- + +private[streaming] class SessionRDDPartition( + idx: Int, + @transient private var previousSessionRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[dstream] var previousSessionRDDPartition: Partition = null + private[dstream] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = previousSessionRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + +private[streaming] class SessionRDD[K: ClassTag, V: ClassTag, S: ClassTag]( + _sc: SparkContext, + private var previousSessionRDD: RDD[SessionMap[K, S]], + private var partitionedDataRDD: RDD[(K, V)], + updateFunction: (V, Option[S]) => Option[S], + timestamp: Long + ) extends RDD[SessionMap[K, S]]( + _sc, + List(new OneToOneDependency(previousSessionRDD), new OneToOneDependency(partitionedDataRDD)) + ) { + + require(partitionedDataRDD.partitioner == previousSessionRDD.partitioner) + + override val partitioner = previousSessionRDD.partitioner + + override def compute(partition: Partition, context: TaskContext): Iterator[SessionMap[K, S]] = { + val sessionRDDPartition = partition.asInstanceOf[SessionRDDPartition] + val prevSessionIterator = previousSessionRDD.iterator( + sessionRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + sessionRDDPartition.partitionedDataRDDPartition, context) + + require(prevSessionIterator.hasNext) + + val sessionMap = prevSessionIterator.next().copy() + dataIterator.foreach { case (key, value) => + val prevState = sessionMap.get(key) + val newState = updateFunction(value, prevState) + if (newState.isDefined) { + sessionMap.put(key, newState.get) 
+ } else { + sessionMap.remove(key) + } + } + Iterator(sessionMap) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(previousSessionRDD.partitions.length) { i => + new SessionRDDPartition(i, previousSessionRDD, partitionedDataRDD)} + } +} + +private[streaming] object SessionRDD { + def createFromPairRDD[K: ClassTag, S: ClassTag]( + pairRDD: RDD[(K, S)], partitioner: Partitioner): RDD[SessionMap[K, S]] = { + + val createStateMap = (iterator: Iterator[(K, S)]) => { + val newSessionMap = SessionMap.create[K, S]() + iterator.foreach { case (key, state) => newSessionMap.put(key, state) } + Iterator(newSessionMap) + } + pairRDD.partitionBy(partitioner).mapPartitions[SessionMap[K, S]]( + createStateMap, preservesPartitioning = true) + } +} + + +// ----------------------------------------------- +// ---------------- SessionDStream --------------- +// ----------------------------------------------- + + +private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( + parent: DStream[(K, V)], sessionSpec: SessionSpec[K, V, S]) + extends DStream[SessionMap[K, S]](parent.context) { + + sessionSpec.validate() + persist(StorageLevel.DISK_ONLY) + + private val partitioner = sessionSpec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val updateFunction = sessionSpec.getUpdateFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[SessionMap[K, S]]] = { + val previousSessionMapRDD = getOrCompute(validTime - slideDuration).getOrElse { + SessionRDD.createFromPairRDD[K, S]( + sessionSpec.getInitialSessions().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner + ) + } + val newDataRDD = parent.getOrCompute(validTime).get + val partitionedDataRDD = newDataRDD.partitionBy(partitioner) + Some(new SessionRDD( + ssc.sparkContext, previousSessionMapRDD, partitionedDataRDD, + updateFunction, validTime.milliseconds)) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 255376807c95..24a834e4dbac 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,13 +22,11 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} +import org.apache.spark.streaming.dstream.{Session, SessionSpec, DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} -import org.apache.spark.HashPartitioner +import org.apache.spark.{HashPartitioner, SparkConf, SparkException} class BasicOperationsSuite extends TestSuiteBase { test("map") { @@ -631,6 +629,71 @@ class BasicOperationsSuite extends TestSuiteBase { } } + + test("sessionByKey") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 
1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ).map { _.map { case (key, value) => Session(key, value, true) } } + + val sessionOperation = (s: DStream[String]) => { + val updateFunc = (value: Int, sessionData: Option[Int]) => { + Option(value + sessionData.getOrElse(0)) + } + s.map(x => (x, 1)).sessionByKey(SessionSpec.create[String, Int, Int](updateFunc)) + } + + testOperation(inputData, sessionOperation, outputData, true) + } + + test("sessionByKey - with all sessions") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ).map { _.map { case (key, value) => Session(key, value, true) } } + + val sessionOperation = (s: DStream[String]) => { + val updateFunc = (value: Int, sessionData: Option[Int]) => { + Option(value + sessionData.getOrElse(0)) + } + s.map(x => (x, 1)).sessionByKey( + SessionSpec.create[String, Int, Int](updateFunc).reportAllSession(true)) + } + + testOperation(inputData, sessionOperation, outputData, true) + } + + /** Test cleanup of RDDs in DStream metadata */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala new file mode 100644 index 000000000000..29c0b2d68b43 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala @@ -0,0 +1,82 @@ +package org.apache.spark.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.dstream.HashMapBasedSessionMap +import org.apache.spark.streaming.dstream.Session + +class SessionMapSuite extends SparkFunSuite { + test("put, get, remove, iterator") { + val map = new HashMapBasedSessionMap[Int, Int]() + + map.put(1, 100) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + map.put(2, 200) + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, true), Session(2, 200, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, true), Session(2, 200, true))) + + map.remove(1) + assert(map.get(1) === None) + + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + } + + test("put, get, remove, iterator after copy") { + val parentMap = new HashMapBasedSessionMap[Int, Int]() + parentMap.put(1, 100) + parentMap.put(2, 200) + parentMap.remove(1) + + val map = parentMap.copy() + assert(map.iterator(updatedSessionsOnly = true).toSet === Set()) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + + map.put(3, 300) + map.put(4, 400) + map.remove(4) + + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(3, 300, true), Session(4, 400, false))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true), + Session(3, 300, true), Session(4, 400, false))) + + assert(parentMap.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + assert(parentMap.iterator(updatedSessionsOnly 
= false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + + map.put(1, 1000) + map.put(2, 2000) + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(3, 300, true), Session(4, 400, false), + Session(1, 1000, true), Session(2, 2000, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 1000, true), Session(2, 2000, true), + Session(3, 300, true), Session(4, 400, false))) + } + + test("copying with consolidation") { + val map1 = new HashMapBasedSessionMap[Int, Int]() + map1.put(1, 100) + map1.put(2, 200) + + val map2 = map1.copy() + map2.put(3, 300) + map2.put(4, 400) + + val map3 = map2.copy() + map3.put(3, 600) + map3.put(4, 700) + + assert(map3.iterator(false).toSet === + map3.asInstanceOf[HashMapBasedSessionMap[Int, Int]].doCopy(true).iterator(false).toSet) + + } +} From ff7731207dc415f951b9451d9b9d3aa2c7ef99cf Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 8 Oct 2015 18:46:11 -0700 Subject: [PATCH 02/26] Renamed SessionMap to SessionStore and fixed checkpointing bug in SessionRDD --- .../streaming/StatefulNetworkWordCount.scala | 5 +- .../streaming/dstream/SessionDStream.scala | 92 ++++++++++--------- ...MapSuite.scala => SessionStoreSuite.scala} | 12 +-- 3 files changed, 58 insertions(+), 51 deletions(-) rename streaming/src/test/scala/org/apache/spark/streaming/{SessionMapSuite.scala => SessionStoreSuite.scala} (86%) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 944ad2be159b..4a970cc5b62b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -55,11 +55,12 @@ object StatefulNetworkWordCount { val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - }*/ + } + */ val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size - val ssc = new StreamingContext(sparkConf, Seconds(1)) + val ssc = new StreamingContext(sparkConf, Milliseconds(200)) ssc.checkpoint(".") // Initial RDD input to updateStateByKey diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 077179efa0a3..f97813e1334b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -108,13 +108,13 @@ object SessionSpec { // ----------------------------------------------- -// --------------- SessionMap stuff -------------- +// --------------- SessionStore stuff -------------- // ----------------------------------------------- /** * Internal interface for defining the map that keeps track of sessions. */ -private[streaming] abstract class SessionMap[K: ClassTag, S: ClassTag] extends Serializable { +private[streaming] abstract class SessionStore[K: ClassTag, S: ClassTag] extends Serializable { /** Add or update session data */ def put(key: K, session: S): Unit @@ -126,10 +126,10 @@ private[streaming] abstract class SessionMap[K: ClassTag, S: ClassTag] extends S def remove(key: K): Unit /** - * Shallow copy the map to create a new session map. 
Updates to the new map + * Shallow copy the map to create a new session store. Updates to the new map * should not mutate `this` map. */ - def copy(): SessionMap[K, S] + def copy(): SessionStore[K, S] /** * Return an iterator of data in this map. If th flag is true, implementations should @@ -138,32 +138,32 @@ private[streaming] abstract class SessionMap[K: ClassTag, S: ClassTag] extends S def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] } -private[streaming] object SessionMap { - def empty[K: ClassTag, S: ClassTag]: SessionMap[K, S] = new EmptySessionMap[K, S] +private[streaming] object SessionStore { + def empty[K: ClassTag, S: ClassTag]: SessionStore[K, S] = new EmptySessionStore[K, S] - def create[K: ClassTag, S: ClassTag](): SessionMap[K, S] = new HashMapBasedSessionMap[K, S]() + def create[K: ClassTag, S: ClassTag](): SessionStore[K, S] = new HashMapBasedSessionStore[K, S]() } -/** Specific implementation of SessionMap interface representing an empty map */ -private[streaming] class EmptySessionMap[K: ClassTag, S: ClassTag] extends SessionMap[K, S] { +/** Specific implementation of SessionStore interface representing an empty map */ +private[streaming] class EmptySessionStore[K: ClassTag, S: ClassTag] extends SessionStore[K, S] { override def put(key: K, session: S): Unit = ??? override def get(key: K): Option[S] = None - override def copy(): SessionMap[K, S] = new EmptySessionMap[K, S] + override def copy(): SessionStore[K, S] = new EmptySessionStore[K, S] override def remove(key: K): Unit = { } override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = Iterator.empty } /** Specific implementation of the SessionMap interface using a scala mutable HashMap */ -private[streaming] class HashMapBasedSessionMap[K: ClassTag, S: ClassTag]( - parentSessionMap: SessionMap[K, S]) extends SessionMap[K, S] { +private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( + parentSessionStore: SessionStore[K, S]) extends SessionStore[K, S] { - def this() = this(new EmptySessionMap[K, S]) + def this() = this(new EmptySessionStore[K, S]) - import HashMapBasedSessionMap._ + import HashMapBasedSessionStore._ - private val generation: Int = parentSessionMap match { - case map: HashMapBasedSessionMap[_, _] => map.generation + 1 + private val generation: Int = parentSessionStore match { + case map: HashMapBasedSessionStore[_, _] => map.generation + 1 case _ => 1 } @@ -180,7 +180,7 @@ private[streaming] class HashMapBasedSessionMap[K: ClassTag, S: ClassTag]( /** Get the session data if it exists */ override def get(key: K): Option[S] = { - internalMap.get(key).filter { _.deleted == false }.map { _.data }.orElse(parentSessionMap.get(key)) + internalMap.get(key).filter { _.deleted == false }.map { _.data }.orElse(parentSessionStore.get(key)) } /** Remove a key */ @@ -197,7 +197,7 @@ private[streaming] class HashMapBasedSessionMap[K: ClassTag, S: ClassTag]( Session(key, sessionInfo.data, !sessionInfo.deleted) } - def previousSessions = parentSessionMap.iterator(updatedSessionsOnly = false).filter { session => + def previousSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { session => !internalMap.contains(session.getKey()) } @@ -209,27 +209,27 @@ private[streaming] class HashMapBasedSessionMap[K: ClassTag, S: ClassTag]( } /** - * Shallow copy the map to create a new session map. Updates to the new map + * Shallow copy the map to create a new session store. Updates to the new map * should not mutate `this` map. 
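   * A copy normally just chains a new empty delta on top of this store; once the chain of
   * parent maps reaches GENERATION_THRESHOLD_FOR_CONSOLIDATION, the copy consolidates all
   * active sessions into a single map to keep lookups cheap.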
*/ - override def copy(): SessionMap[K, S] = { - doCopy(generation >= HashMapBasedSessionMap.GENERATION_THRESHOLD_FOR_CONSOLIDATION) + override def copy(): SessionStore[K, S] = { + doCopy(generation >= HashMapBasedSessionStore.GENERATION_THRESHOLD_FOR_CONSOLIDATION) } - def doCopy(consolidate: Boolean): SessionMap[K, S] = { + def doCopy(consolidate: Boolean): SessionStore[K, S] = { if (consolidate) { - val newParentMap = new HashMapBasedSessionMap[K, S]() + val newParentMap = new HashMapBasedSessionStore[K, S]() iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => newParentMap.internalMap.put(session.getKey(), SessionInfo(session.getData(), deleted = false)) } - new HashMapBasedSessionMap[K, S](newParentMap) + new HashMapBasedSessionStore[K, S](newParentMap) } else { - new HashMapBasedSessionMap[K, S](this) + new HashMapBasedSessionStore[K, S](this) } } } -private[streaming] object HashMapBasedSessionMap { +private[streaming] object HashMapBasedSessionStore { case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) @@ -263,11 +263,11 @@ private[streaming] class SessionRDDPartition( private[streaming] class SessionRDD[K: ClassTag, V: ClassTag, S: ClassTag]( _sc: SparkContext, - private var previousSessionRDD: RDD[SessionMap[K, S]], + private var previousSessionRDD: RDD[SessionStore[K, S]], private var partitionedDataRDD: RDD[(K, V)], updateFunction: (V, Option[S]) => Option[S], timestamp: Long - ) extends RDD[SessionMap[K, S]]( + ) extends RDD[SessionStore[K, S]]( _sc, List(new OneToOneDependency(previousSessionRDD), new OneToOneDependency(partitionedDataRDD)) ) { @@ -276,7 +276,7 @@ private[streaming] class SessionRDD[K: ClassTag, V: ClassTag, S: ClassTag]( override val partitioner = previousSessionRDD.partitioner - override def compute(partition: Partition, context: TaskContext): Iterator[SessionMap[K, S]] = { + override def compute(partition: Partition, context: TaskContext): Iterator[SessionStore[K, S]] = { val sessionRDDPartition = partition.asInstanceOf[SessionRDDPartition] val prevSessionIterator = previousSessionRDD.iterator( sessionRDDPartition.previousSessionRDDPartition, context) @@ -285,35 +285,41 @@ private[streaming] class SessionRDD[K: ClassTag, V: ClassTag, S: ClassTag]( require(prevSessionIterator.hasNext) - val sessionMap = prevSessionIterator.next().copy() + val sessionStore = prevSessionIterator.next().copy() dataIterator.foreach { case (key, value) => - val prevState = sessionMap.get(key) + val prevState = sessionStore.get(key) val newState = updateFunction(value, prevState) if (newState.isDefined) { - sessionMap.put(key, newState.get) + sessionStore.put(key, newState.get) } else { - sessionMap.remove(key) + sessionStore.remove(key) } } - Iterator(sessionMap) + Iterator(sessionStore) } override protected def getPartitions: Array[Partition] = { Array.tabulate(previousSessionRDD.partitions.length) { i => new SessionRDDPartition(i, previousSessionRDD, partitionedDataRDD)} } + + override def clearDependencies() { + super.clearDependencies() + previousSessionRDD = null + partitionedDataRDD = null + } } private[streaming] object SessionRDD { def createFromPairRDD[K: ClassTag, S: ClassTag]( - pairRDD: RDD[(K, S)], partitioner: Partitioner): RDD[SessionMap[K, S]] = { + pairRDD: RDD[(K, S)], partitioner: Partitioner): RDD[SessionStore[K, S]] = { val createStateMap = (iterator: Iterator[(K, S)]) => { - val newSessionMap = SessionMap.create[K, S]() - iterator.foreach { case (key, state) => 
newSessionMap.put(key, state) } - Iterator(newSessionMap) + val newSessionStore = SessionStore.create[K, S]() + iterator.foreach { case (key, state) => newSessionStore.put(key, state) } + Iterator(newSessionStore) } - pairRDD.partitionBy(partitioner).mapPartitions[SessionMap[K, S]]( + pairRDD.partitionBy(partitioner).mapPartitions[SessionStore[K, S]]( createStateMap, preservesPartitioning = true) } } @@ -326,7 +332,7 @@ private[streaming] object SessionRDD { private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( parent: DStream[(K, V)], sessionSpec: SessionSpec[K, V, S]) - extends DStream[SessionMap[K, S]](parent.context) { + extends DStream[SessionStore[K, S]](parent.context) { sessionSpec.validate() persist(StorageLevel.DISK_ONLY) @@ -343,8 +349,8 @@ private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( override val mustCheckpoint = true /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[SessionMap[K, S]]] = { - val previousSessionMapRDD = getOrCompute(validTime - slideDuration).getOrElse { + override def compute(validTime: Time): Option[RDD[SessionStore[K, S]]] = { + val previousSessionRDD = getOrCompute(validTime - slideDuration).getOrElse { SessionRDD.createFromPairRDD[K, S]( sessionSpec.getInitialSessions().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), partitioner @@ -353,7 +359,7 @@ private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val newDataRDD = parent.getOrCompute(validTime).get val partitionedDataRDD = newDataRDD.partitionBy(partitioner) Some(new SessionRDD( - ssc.sparkContext, previousSessionMapRDD, partitionedDataRDD, + ssc.sparkContext, previousSessionRDD, partitionedDataRDD, updateFunction, validTime.milliseconds)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala similarity index 86% rename from streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala index 29c0b2d68b43..cf30fc31fa0c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/SessionMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala @@ -1,12 +1,12 @@ package org.apache.spark.streaming import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.dstream.HashMapBasedSessionMap +import org.apache.spark.streaming.dstream.HashMapBasedSessionStore import org.apache.spark.streaming.dstream.Session -class SessionMapSuite extends SparkFunSuite { +class SessionStoreSuite extends SparkFunSuite { test("put, get, remove, iterator") { - val map = new HashMapBasedSessionMap[Int, Int]() + val map = new HashMapBasedSessionStore[Int, Int]() map.put(1, 100) assert(map.get(1) === Some(100)) @@ -27,7 +27,7 @@ class SessionMapSuite extends SparkFunSuite { } test("put, get, remove, iterator after copy") { - val parentMap = new HashMapBasedSessionMap[Int, Int]() + val parentMap = new HashMapBasedSessionStore[Int, Int]() parentMap.put(1, 100) parentMap.put(2, 200) parentMap.remove(1) @@ -63,7 +63,7 @@ class SessionMapSuite extends SparkFunSuite { } test("copying with consolidation") { - val map1 = new HashMapBasedSessionMap[Int, Int]() + val map1 = new HashMapBasedSessionStore[Int, Int]() map1.put(1, 100) map1.put(2, 200) @@ -76,7 +76,7 @@ class SessionMapSuite extends SparkFunSuite { map3.put(4, 700) 
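    // a consolidated copy (doCopy(true)) must report exactly the same sessions as the chained copies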
assert(map3.iterator(false).toSet === - map3.asInstanceOf[HashMapBasedSessionMap[Int, Int]].doCopy(true).iterator(false).toSet) + map3.asInstanceOf[HashMapBasedSessionStore[Int, Int]].doCopy(true).iterator(false).toSet) } } From 1fea358dda31ebc5bf089e95f12a9c56165e4555 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 9 Oct 2015 17:25:28 -0700 Subject: [PATCH 03/26] Added OpenHashMapBasedSessionStore --- .../streaming/dstream/SessionDStream.scala | 116 ++++++++++-- .../spark/streaming/SessionStoreSuite.scala | 172 ++++++++++-------- 2 files changed, 205 insertions(+), 83 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index f97813e1334b..c140e811782b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -10,6 +10,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.util.Utils +import org.apache.spark.util.collection.OpenHashMap // ================================================== @@ -51,8 +52,8 @@ private[streaming] object Session { } /** Class representing all the specification of session */ -class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]() extends Serializable { - @volatile private var updateFunction: (V, Option[S]) => Option[S] = null +class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]( + updateFunction: (V, Option[S]) => Option[S]) extends Serializable { @volatile private var partitioner: Partitioner = null @volatile private var initialSessionRDD: RDD[(K, S)] = null @volatile private var allSessions: Boolean = false @@ -62,11 +63,6 @@ class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]() ex this } - def setUpdateFunction(func: (V, Option[S]) => Option[S]): this.type = { - updateFunction = func - this - } - def setInitialSessions(initialRDD: RDD[(K, S)]): this.type = { this.initialSessionRDD = initialRDD this @@ -92,8 +88,8 @@ class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]() ex object SessionSpec { def create[K: ClassTag, V: ClassTag, S: ClassTag]( - updateFunction: (V, Option[S]) => Option[S]): SessionSpec[K, V, S] = { - new SessionSpec[K, V, S].setUpdateFunction(updateFunction) + updateFunction: (V, Option[S]) => Option[S]): SessionSpec[K, V, S] = { + new SessionSpec[K, V, S](updateFunction) } } @@ -141,7 +137,7 @@ private[streaming] abstract class SessionStore[K: ClassTag, S: ClassTag] extends private[streaming] object SessionStore { def empty[K: ClassTag, S: ClassTag]: SessionStore[K, S] = new EmptySessionStore[K, S] - def create[K: ClassTag, S: ClassTag](): SessionStore[K, S] = new HashMapBasedSessionStore[K, S]() + def create[K: ClassTag, S: ClassTag](): SessionStore[K, S] = new OpenHashMapBasedSessionStore[K, S]() } /** Specific implementation of SessionStore interface representing an empty map */ @@ -231,12 +227,110 @@ private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( private[streaming] object HashMapBasedSessionStore { - case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) + private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) + + val 
GENERATION_THRESHOLD_FOR_CONSOLIDATION = 10 +} + + +private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( + parentSessionStore: SessionStore[K, S], initialCapacity: Int) extends SessionStore[K, S] { + + def this(initialCapacity: Int) = this(new EmptySessionStore[K, S], initialCapacity) + + def this() = this(64) + + + import OpenHashMapBasedSessionStore._ + + private val generation: Int = parentSessionStore match { + case map: OpenHashMapBasedSessionStore[_, _] => map.generation + 1 + case _ => 1 + } + + private val internalMap = new OpenHashMap[K, SessionInfo[S]]() + + override def put(key: K, session: S): Unit = { + val sessionInfo = internalMap(key) + if (sessionInfo != null) { + sessionInfo.data = session + } else { + internalMap.update(key, new SessionInfo(session)) + } + } + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val sessionInfo = internalMap(key) + if (sessionInfo != null && sessionInfo.deleted == false) { + Some(sessionInfo.data) + } else { + parentSessionStore.get(key) + } + } + + /** Remove a key */ + override def remove(key: K): Unit = { + internalMap.update(key, new SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) + } + + /** + * Return an iterator of data in this map. If th flag is true, implementations should + * return only the session that were updated since the creation of this map. + */ + override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = { + val updatedSessions = internalMap.iterator.map { case (key, sessionInfo) => + Session(key, sessionInfo.data, !sessionInfo.deleted) + } + + def previousSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { session => + !internalMap.contains(session.getKey()) + } + + if (updatedSessionsOnly) { + updatedSessions + } else { + previousSessions ++ updatedSessions + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. 
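+   * When the generation count crosses the consolidation threshold, the copy flattens all
+   * active sessions into a single OpenHashMap sized with sizeHint(), so lookup chains stay short.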
+ */ + override def copy(): SessionStore[K, S] = { + doCopy(generation >= HashMapBasedSessionStore.GENERATION_THRESHOLD_FOR_CONSOLIDATION) + } + + private[streaming] def doCopy(consolidate: Boolean): SessionStore[K, S] = { + if (consolidate) { + val newParentMap = new OpenHashMapBasedSessionStore[K, S](sizeHint) + iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => + newParentMap.internalMap.update(session.getKey(), SessionInfo(session.getData(), deleted = false)) + } + new HashMapBasedSessionStore[K, S](newParentMap) + } else { + new HashMapBasedSessionStore[K, S](this) + } + } + + private def sizeHint(): Int = internalMap.size + { + parentSessionStore match { + case s: OpenHashMapBasedSessionStore[_, _] => s.sizeHint() + case _ => 0 + } + } +} + +private[streaming] object OpenHashMapBasedSessionStore { + + private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) val GENERATION_THRESHOLD_FOR_CONSOLIDATION = 10 } + // ----------------------------------------------- // --------------- SessionRDD stuff -------------- // ----------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala index cf30fc31fa0c..b3016f754c1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala @@ -1,82 +1,110 @@ package org.apache.spark.streaming +import scala.reflect._ + import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.dstream.HashMapBasedSessionStore -import org.apache.spark.streaming.dstream.Session +import org.apache.spark.streaming.dstream.{OpenHashMapBasedSessionStore, HashMapBasedSessionStore, Session, SessionStore} class SessionStoreSuite extends SparkFunSuite { - test("put, get, remove, iterator") { - val map = new HashMapBasedSessionStore[Int, Int]() - - map.put(1, 100) - assert(map.get(1) === Some(100)) - assert(map.get(2) === None) - map.put(2, 200) - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, true), Session(2, 200, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, true), Session(2, 200, true))) - - map.remove(1) - assert(map.get(1) === None) - - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - } - test("put, get, remove, iterator after copy") { - val parentMap = new HashMapBasedSessionStore[Int, Int]() - parentMap.put(1, 100) - parentMap.put(2, 200) - parentMap.remove(1) - - val map = parentMap.copy() - assert(map.iterator(updatedSessionsOnly = true).toSet === Set()) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - - map.put(3, 300) - map.put(4, 400) - map.remove(4) - - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(3, 300, true), Session(4, 400, false))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true), - Session(3, 300, true), Session(4, 400, false))) - - assert(parentMap.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - assert(parentMap.iterator(updatedSessionsOnly = 
false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - - map.put(1, 1000) - map.put(2, 2000) - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(3, 300, true), Session(4, 400, false), - Session(1, 1000, true), Session(2, 2000, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 1000, true), Session(2, 2000, true), - Session(3, 300, true), Session(4, 400, false))) + HashMapBasedSessionStoreTester.testStore() + OpenHashMapBasedSessionStoreTester.testStore() + + abstract class SessionStoreTester[StoreType <: SessionStore[Int, Int]: ClassTag] { + + private val clazz = classTag[StoreType].runtimeClass + private val className = clazz.getSimpleName + + protected def newStore(): StoreType + + def testStore(): Unit = { + + test(className + "- put, get, remove, iterator") { + val map = newStore() + + map.put(1, 100) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + map.put(2, 200) + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, true), Session(2, 200, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, true), Session(2, 200, true))) + + map.remove(1) + assert(map.get(1) === None) + + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + } + + test(className + " - put, get, remove, iterator after copy") { + val parentMap = newStore() + parentMap.put(1, 100) + parentMap.put(2, 200) + parentMap.remove(1) + + val map = parentMap.copy() + assert(map.iterator(updatedSessionsOnly = true).toSet === Set()) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + + map.put(3, 300) + map.put(4, 400) + map.remove(4) + + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(3, 300, true), Session(4, 400, false))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true), + Session(3, 300, true), Session(4, 400, false))) + + assert(parentMap.iterator(updatedSessionsOnly = true).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + assert(parentMap.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 100, false), Session(2, 200, true))) + + map.put(1, 1000) + map.put(2, 2000) + assert(map.iterator(updatedSessionsOnly = true).toSet === + Set(Session(3, 300, true), Session(4, 400, false), + Session(1, 1000, true), Session(2, 2000, true))) + assert(map.iterator(updatedSessionsOnly = false).toSet === + Set(Session(1, 1000, true), Session(2, 2000, true), + Session(3, 300, true), Session(4, 400, false))) + } + + test(className + " - copying with consolidation") { + val map1 = newStore() + map1.put(1, 100) + map1.put(2, 200) + + val map2 = map1.copy() + map2.put(3, 300) + map2.put(4, 400) + + val map3 = map2.copy() + map3.put(3, 600) + map3.put(4, 700) + + assert(map3.iterator(false).toSet === + map3.asInstanceOf[HashMapBasedSessionStore[Int, Int]].doCopy(true).iterator(false).toSet) + + } + } } - test("copying with consolidation") { - val map1 = new HashMapBasedSessionStore[Int, Int]() - map1.put(1, 100) - map1.put(2, 200) - - val map2 = map1.copy() - map2.put(3, 300) - map2.put(4, 400) - - val map3 = map2.copy() - map3.put(3, 600) - map3.put(4, 700) - - assert(map3.iterator(false).toSet === - 
map3.asInstanceOf[HashMapBasedSessionStore[Int, Int]].doCopy(true).iterator(false).toSet) + object HashMapBasedSessionStoreTester extends SessionStoreTester[HashMapBasedSessionStore[Int, Int]] { + override protected def newStore(): HashMapBasedSessionStore[Int, Int] = { + new HashMapBasedSessionStore[Int, Int]() + } + } + object OpenHashMapBasedSessionStoreTester extends SessionStoreTester[OpenHashMapBasedSessionStore[Int, Int]] { + override protected def newStore(): OpenHashMapBasedSessionStore[Int, Int] = { + new OpenHashMapBasedSessionStore[Int, Int]() + } } -} +} \ No newline at end of file From 27dbabc8e3d44467fdc10f2333ef42d59c6f85d8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Oct 2015 22:15:56 -0700 Subject: [PATCH 04/26] Fixed bugs --- .../streaming/StatefulNetworkWordCount.scala | 2 +- .../streaming/dstream/SessionDStream.scala | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 4a970cc5b62b..cd653f57f5af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -60,7 +60,7 @@ object StatefulNetworkWordCount { val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size - val ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val ssc = new StreamingContext(sparkConf, Milliseconds(2000)) ssc.checkpoint(".") // Initial RDD input to updateStateByKey diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index c140e811782b..8032824e6ccc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -234,12 +234,13 @@ private[streaming] object HashMapBasedSessionStore { private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( - parentSessionStore: SessionStore[K, S], initialCapacity: Int) extends SessionStore[K, S] { + parentSessionStore: SessionStore[K, S], + initialCapacity: Int = 64 + ) extends SessionStore[K, S] { def this(initialCapacity: Int) = this(new EmptySessionStore[K, S], initialCapacity) - def this() = this(64) - + def this() = this(new EmptySessionStore[K, S]) import OpenHashMapBasedSessionStore._ @@ -248,6 +249,8 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( case _ => 1 } + println("Generation " + generation) + private val internalMap = new OpenHashMap[K, SessionInfo[S]]() override def put(key: K, session: S): Unit = { @@ -299,18 +302,18 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * should not mutate `this` map. 
*/ override def copy(): SessionStore[K, S] = { - doCopy(generation >= HashMapBasedSessionStore.GENERATION_THRESHOLD_FOR_CONSOLIDATION) + doCopy(generation >= GENERATION_THRESHOLD_FOR_CONSOLIDATION) } private[streaming] def doCopy(consolidate: Boolean): SessionStore[K, S] = { if (consolidate) { - val newParentMap = new OpenHashMapBasedSessionStore[K, S](sizeHint) + val newParentMap = new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint) iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => newParentMap.internalMap.update(session.getKey(), SessionInfo(session.getData(), deleted = false)) } - new HashMapBasedSessionStore[K, S](newParentMap) + new OpenHashMapBasedSessionStore[K, S](newParentMap) } else { - new HashMapBasedSessionStore[K, S](this) + new OpenHashMapBasedSessionStore[K, S](this) } } @@ -429,7 +432,7 @@ private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( extends DStream[SessionStore[K, S]](parent.context) { sessionSpec.validate() - persist(StorageLevel.DISK_ONLY) + persist(StorageLevel.MEMORY_ONLY) private val partitioner = sessionSpec.getPartitioner().getOrElse( new HashPartitioner(ssc.sc.defaultParallelism)) From 3fc50675b3471c14abded7cbfd23693c81c2b58d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Oct 2015 22:42:52 -0700 Subject: [PATCH 05/26] Made delta chain threshold configurable --- .../streaming/dstream/SessionDStream.scala | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 8032824e6ccc..0dc9fc701405 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -137,7 +137,11 @@ private[streaming] abstract class SessionStore[K: ClassTag, S: ClassTag] extends private[streaming] object SessionStore { def empty[K: ClassTag, S: ClassTag]: SessionStore[K, S] = new EmptySessionStore[K, S] - def create[K: ClassTag, S: ClassTag](): SessionStore[K, S] = new OpenHashMapBasedSessionStore[K, S]() + def create[K: ClassTag, S: ClassTag](conf: SparkConf = null): SessionStore[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedSessionStore[K, S](64, deltaChainThreshold) + } } /** Specific implementation of SessionStore interface representing an empty map */ @@ -158,9 +162,9 @@ private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( import HashMapBasedSessionStore._ - private val generation: Int = parentSessionStore match { - case map: HashMapBasedSessionStore[_, _] => map.generation + 1 - case _ => 1 + private val deltaChainLength: Int = parentSessionStore match { + case map: HashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 + case _ => 0 } private val internalMap = new mutable.HashMap[K, SessionInfo[S]] @@ -209,7 +213,7 @@ private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * should not mutate `this` map. 
*/ override def copy(): SessionStore[K, S] = { - doCopy(generation >= HashMapBasedSessionStore.GENERATION_THRESHOLD_FOR_CONSOLIDATION) + doCopy(deltaChainLength >= HashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) } def doCopy(consolidate: Boolean): SessionStore[K, S] = { @@ -229,29 +233,31 @@ private[streaming] object HashMapBasedSessionStore { private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) - val GENERATION_THRESHOLD_FOR_CONSOLIDATION = 10 + val DELTA_CHAIN_LENGTH_THRESHOLD = 10 } private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( parentSessionStore: SessionStore[K, S], + deltaChainThreshold: Int, initialCapacity: Int = 64 ) extends SessionStore[K, S] { - def this(initialCapacity: Int) = this(new EmptySessionStore[K, S], initialCapacity) + def this(initialCapacity: Int, deltaChainThreshold: Int) = + this(new EmptySessionStore[K, S], initialCapacity, deltaChainThreshold) - def this() = this(new EmptySessionStore[K, S]) + def this(deltaChainThreshold: Int) = this(64, deltaChainThreshold) + + def this() = this(OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) import OpenHashMapBasedSessionStore._ - private val generation: Int = parentSessionStore match { - case map: OpenHashMapBasedSessionStore[_, _] => map.generation + 1 - case _ => 1 + private val deltaChainLength: Int = parentSessionStore match { + case map: OpenHashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 + case _ => 0 } - println("Generation " + generation) - - private val internalMap = new OpenHashMap[K, SessionInfo[S]]() + private val internalMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) override def put(key: K, session: S): Unit = { val sessionInfo = internalMap(key) @@ -302,18 +308,19 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * should not mutate `this` map. 
*/ override def copy(): SessionStore[K, S] = { - doCopy(generation >= GENERATION_THRESHOLD_FOR_CONSOLIDATION) + doCopy(deltaChainLength >= DELTA_CHAIN_LENGTH_THRESHOLD) } private[streaming] def doCopy(consolidate: Boolean): SessionStore[K, S] = { if (consolidate) { - val newParentMap = new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint) + val newParentMap = new OpenHashMapBasedSessionStore[K, S]( + initialCapacity = sizeHint, deltaChainThreshold) iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => newParentMap.internalMap.update(session.getKey(), SessionInfo(session.getData(), deleted = false)) } - new OpenHashMapBasedSessionStore[K, S](newParentMap) + new OpenHashMapBasedSessionStore[K, S](newParentMap, deltaChainThreshold = deltaChainThreshold) } else { - new OpenHashMapBasedSessionStore[K, S](this) + new OpenHashMapBasedSessionStore[K, S](this, deltaChainThreshold = deltaChainThreshold) } } @@ -329,7 +336,7 @@ private[streaming] object OpenHashMapBasedSessionStore { private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) - val GENERATION_THRESHOLD_FOR_CONSOLIDATION = 10 + val DELTA_CHAIN_LENGTH_THRESHOLD = 10 } From 514eb017406d531bba41070323cd51efe213ccad Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Oct 2015 23:18:05 -0700 Subject: [PATCH 06/26] Fixed NPE --- .../org/apache/spark/streaming/dstream/SessionDStream.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 0dc9fc701405..753e67c26d38 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -137,7 +137,7 @@ private[streaming] abstract class SessionStore[K: ClassTag, S: ClassTag] extends private[streaming] object SessionStore { def empty[K: ClassTag, S: ClassTag]: SessionStore[K, S] = new EmptySessionStore[K, S] - def create[K: ClassTag, S: ClassTag](conf: SparkConf = null): SessionStore[K, S] = { + def create[K: ClassTag, S: ClassTag](conf: SparkConf): SessionStore[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) new OpenHashMapBasedSessionStore[K, S](64, deltaChainThreshold) @@ -419,7 +419,7 @@ private[streaming] object SessionRDD { pairRDD: RDD[(K, S)], partitioner: Partitioner): RDD[SessionStore[K, S]] = { val createStateMap = (iterator: Iterator[(K, S)]) => { - val newSessionStore = SessionStore.create[K, S]() + val newSessionStore = SessionStore.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => newSessionStore.put(key, state) } Iterator(newSessionStore) } From d5b2bec581cccb97b27878b39914b1e5327c3f86 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Oct 2015 06:22:38 -0700 Subject: [PATCH 07/26] consolidation while checkpointing --- .../streaming/dstream/SessionDStream.scala | 105 +++++++++++++----- .../spark/streaming/SessionStoreSuite.scala | 23 +++- 2 files changed, 98 insertions(+), 30 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 753e67c26d38..6541cced1ad3 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -1,6 +1,6 @@ package org.apache.spark.streaming.dstream -import java.io.{IOException, ObjectOutputStream} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.reflect.ClassTag @@ -160,8 +160,6 @@ private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( def this() = this(new EmptySessionStore[K, S]) - import HashMapBasedSessionStore._ - private val deltaChainLength: Int = parentSessionStore match { case map: HashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 case _ => 0 @@ -230,16 +228,13 @@ private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( } private[streaming] object HashMapBasedSessionStore { - - private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) - val DELTA_CHAIN_LENGTH_THRESHOLD = 10 } private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( - parentSessionStore: SessionStore[K, S], - deltaChainThreshold: Int, + @volatile private var parentSessionStore: SessionStore[K, S], + @volatile private var deltaChainThreshold: Int, initialCapacity: Int = 64 ) extends SessionStore[K, S] { @@ -250,27 +245,20 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( def this() = this(OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) - import OpenHashMapBasedSessionStore._ - - private val deltaChainLength: Int = parentSessionStore match { - case map: OpenHashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 - case _ => 0 - } - - private val internalMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) + private var deltaMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) override def put(key: K, session: S): Unit = { - val sessionInfo = internalMap(key) + val sessionInfo = deltaMap(key) if (sessionInfo != null) { sessionInfo.data = session } else { - internalMap.update(key, new SessionInfo(session)) + deltaMap.update(key, new SessionInfo(session)) } } /** Get the session data if it exists */ override def get(key: K): Option[S] = { - val sessionInfo = internalMap(key) + val sessionInfo = deltaMap(key) if (sessionInfo != null && sessionInfo.deleted == false) { Some(sessionInfo.data) } else { @@ -280,7 +268,7 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( /** Remove a key */ override def remove(key: K): Unit = { - internalMap.update(key, new SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) + deltaMap.update(key, new SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) } /** @@ -288,12 +276,12 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * return only the session that were updated since the creation of this map. 
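 *
 * A simplified, self-contained sketch of how such a two-layer iteration can behave, with
 * removals kept as tombstones in the delta (all names here are illustrative, not the ones
 * used by this class): parent entries shadowed by the delta are skipped, and the delta
 * alone answers the "updated only" case.
 * {{{
 *   case class Entry[K, V](key: K, value: V, active: Boolean)    // active == !deleted
 *
 *   def iterate[K, V](parent: Seq[Entry[K, V]],
 *                     delta: Map[K, (V, Boolean)],                // value plus deleted flag
 *                     updatedOnly: Boolean): Seq[Entry[K, V]] = {
 *     val updated = delta.toSeq.map { case (k, (v, deleted)) => Entry(k, v, !deleted) }
 *     if (updatedOnly) updated
 *     else parent.filterNot(e => delta.contains(e.key)) ++ updated
 *   }
 *
 *   val parent = Seq(Entry("a", 1, active = true), Entry("b", 2, active = true))
 *   val delta  = Map("b" -> (20, false), "a" -> (1, true))        // update b, tombstone a
 *   iterate(parent, delta, updatedOnly = false)
 *   // => Seq(Entry(b, 20, true), Entry(a, 1, false)): b updated, a reported as deleted
 * }}}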
*/ override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = { - val updatedSessions = internalMap.iterator.map { case (key, sessionInfo) => + val updatedSessions = deltaMap.iterator.map { case (key, sessionInfo) => Session(key, sessionInfo.data, !sessionInfo.deleted) } def previousSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { session => - !internalMap.contains(session.getKey()) + !deltaMap.contains(session.getKey()) } if (updatedSessionsOnly) { @@ -308,9 +296,9 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * should not mutate `this` map. */ override def copy(): SessionStore[K, S] = { - doCopy(deltaChainLength >= DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedSessionStore[K, S](this, deltaChainThreshold = deltaChainThreshold) } - +/* private[streaming] def doCopy(consolidate: Boolean): SessionStore[K, S] = { if (consolidate) { val newParentMap = new OpenHashMapBasedSessionStore[K, S]( @@ -323,20 +311,81 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( new OpenHashMapBasedSessionStore[K, S](this, deltaChainThreshold = deltaChainThreshold) } } +*/ + private def deltaChainLength: Int = parentSessionStore match { + case map: OpenHashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 + case _ => 0 + } - private def sizeHint(): Int = internalMap.size + { + private def sizeHint(): Int = deltaMap.size + { parentSessionStore match { case s: OpenHashMapBasedSessionStore[_, _] => s.sizeHint() case _ => 0 } } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + if (deltaChainLength > deltaChainThreshold) { + val newParentSessionStore = + new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) + val iterOfActiveSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { + _.isActive + } + + while (iterOfActiveSessions.hasNext) { + val session = iterOfActiveSessions.next() + newParentSessionStore.deltaMap.update( + session.getKey(), SessionInfo(session.getData(), deleted = false)) + } + parentSessionStore = newParentSessionStore + } + outputStream.defaultWriteObject() + + /* + outputStream.writeInt(deltaChainThreshold) + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val keyedSessionInfo = deltaMapIterator.next() + outputStream.writeObject(keyedSessionInfo._1) + outputStream.writeObject(keyedSessionInfo._2) + } + assert(deltaMapCount == deltaMap.size) + */ + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + inputStream.defaultReadObject() + + /* + deltaChainThreshold = inputStream.readInt() + val deltaMapSize = inputStream.readInt() + println(deltaMapSize) + deltaMap = new OpenHashMap[K, SessionInfo[S]]() + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[SessionInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + parentSessionStore = inputStream.readObject().asInstanceOf[SessionStore[K, S]] + */ + + } + + } -private[streaming] object OpenHashMapBasedSessionStore { +class Limiter(val num: Int) extends Serializable - private case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) +case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) - val 
DELTA_CHAIN_LENGTH_THRESHOLD = 10 +private[streaming] object OpenHashMapBasedSessionStore { + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala index b3016f754c1a..c3a615c4ea88 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala @@ -3,7 +3,8 @@ package org.apache.spark.streaming import scala.reflect._ import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.dstream.{OpenHashMapBasedSessionStore, HashMapBasedSessionStore, Session, SessionStore} +import org.apache.spark.streaming.dstream.{HashMapBasedSessionStore, OpenHashMapBasedSessionStore, Session, SessionStore} +import org.apache.spark.util.Utils class SessionStoreSuite extends SparkFunSuite { @@ -75,7 +76,7 @@ class SessionStoreSuite extends SparkFunSuite { Set(Session(1, 1000, true), Session(2, 2000, true), Session(3, 300, true), Session(4, 400, false))) } - + /* test(className + " - copying with consolidation") { val map1 = newStore() map1.put(1, 100) @@ -92,6 +93,24 @@ class SessionStoreSuite extends SparkFunSuite { assert(map3.iterator(false).toSet === map3.asInstanceOf[HashMapBasedSessionStore[Int, Int]].doCopy(true).iterator(false).toSet) + }*/ + + test(className + " - serializing and deserializing") { + val map1 = newStore() + map1.put(1, 100) + map1.put(2, 200) + + val map2 = map1.copy() + map2.put(3, 300) + map2.put(4, 400) + + val map3 = map2.copy() + map3.put(3, 600) + map3.remove(2) + + val map3_ = Utils.deserialize[SessionStore[Int, Int]](Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + assert(map3_.iterator(true).toSet === map3.iterator(true).toSet) + assert(map3_.iterator(false).toSet === map3.iterator(false).toSet) } } } From 58eee1ee9f4960b13df084202ae2ed57e671a744 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Oct 2015 06:40:26 -0700 Subject: [PATCH 08/26] Fixed bug --- .../apache/spark/streaming/dstream/SessionDStream.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 6541cced1ad3..8302d369d63c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -234,14 +234,14 @@ private[streaming] object HashMapBasedSessionStore { private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( @volatile private var parentSessionStore: SessionStore[K, S], - @volatile private var deltaChainThreshold: Int, - initialCapacity: Int = 64 + initialCapacity: Int = 64, + @volatile private var deltaChainThreshold: Int = OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD ) extends SessionStore[K, S] { def this(initialCapacity: Int, deltaChainThreshold: Int) = - this(new EmptySessionStore[K, S], initialCapacity, deltaChainThreshold) + this(new EmptySessionStore[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this(64, deltaChainThreshold) + def this(deltaChainThreshold: Int) = this(initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) def this() = 
this(OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) From 672e3e620774aec13e8b81549a0d7bdbb98f8455 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Oct 2015 07:21:27 -0700 Subject: [PATCH 09/26] optimized serialization --- .../streaming/dstream/SessionDStream.scala | 109 +++++++++++++----- 1 file changed, 82 insertions(+), 27 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 8302d369d63c..6ad966eff409 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -1,6 +1,6 @@ package org.apache.spark.streaming.dstream -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.collection.mutable import scala.reflect.ClassTag @@ -233,9 +233,9 @@ private[streaming] object HashMapBasedSessionStore { private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( - @volatile private var parentSessionStore: SessionStore[K, S], + @transient @volatile private var parentSessionStore: SessionStore[K, S], initialCapacity: Int = 64, - @volatile private var deltaChainThreshold: Int = OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD + deltaChainThreshold: Int = OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD ) extends SessionStore[K, S] { def this(initialCapacity: Int, deltaChainThreshold: Int) = @@ -245,7 +245,7 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( def this() = this(OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) - private var deltaMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) + @transient @volatile private var deltaMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) override def put(key: K, session: S): Unit = { val sessionInfo = deltaMap(key) @@ -324,25 +324,13 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( } } - private def writeObject(outputStream: ObjectOutputStream): Unit = { - if (deltaChainLength > deltaChainThreshold) { - val newParentSessionStore = - new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) - val iterOfActiveSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { - _.isActive - } - while (iterOfActiveSessions.hasNext) { - val session = iterOfActiveSessions.next() - newParentSessionStore.deltaMap.update( - session.getKey(), SessionInfo(session.getData(), deleted = false)) - } - parentSessionStore = newParentSessionStore - } + + + private def writeObject(outputStream: ObjectOutputStream): Unit = { outputStream.defaultWriteObject() - /* - outputStream.writeInt(deltaChainThreshold) + // Write the deltaMap outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator var deltaMapCount = 0 @@ -353,16 +341,42 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( outputStream.writeObject(keyedSessionInfo._2) } assert(deltaMapCount == deltaMap.size) - */ + + // Write the parentSessionStore while consolidating + val consolidate = deltaChainLength > deltaChainThreshold + val newParentSessionStore = if (consolidate) { + new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = 
parentSessionStore.iterator(updatedSessionsOnly = false).filter { _.isActive } + + var parentSessionCount = 0 + + outputStream.writeInt(sizeHint) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val session = iterOfActiveSessions.next() + outputStream.writeObject(session.getKey()) + outputStream.writeObject(session.getData()) + + if (consolidate) { + newParentSessionStore.deltaMap.update( + session.getKey(), SessionInfo(session.getData(), deleted = false)) + } + } + val limiterObj = new Limiter(parentSessionCount) + outputStream.writeObject(limiterObj) + if (consolidate) { + parentSessionStore = newParentSessionStore + } } private def readObject(inputStream: ObjectInputStream): Unit = { inputStream.defaultReadObject() - /* - deltaChainThreshold = inputStream.readInt() val deltaMapSize = inputStream.readInt() - println(deltaMapSize) deltaMap = new OpenHashMap[K, SessionInfo[S]]() var deltaMapCount = 0 while (deltaMapCount < deltaMapSize) { @@ -371,11 +385,52 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( deltaMap.update(key, sessionInfo) deltaMapCount += 1 } - parentSessionStore = inputStream.readObject().asInstanceOf[SessionStore[K, S]] - */ + val parentSessionStoreSizeHint = inputStream.readInt() + val newParentSessionStore = new OpenHashMapBasedSessionStore[K, S]( + initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + //println("Read: " + obj) + if (obj.isInstanceOf[Limiter]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[Limiter].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + newParentSessionStore.deltaMap.update( + key, SessionInfo(state, deleted = false)) + } + } + parentSessionStore = newParentSessionStore } +/* + private def writeObject(outputStream: ObjectOutputStream): Unit = { + if (deltaChainLength > deltaChainThreshold) { + val newParentSessionStore = + new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) + val iterOfActiveSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { + _.isActive + } + + while (iterOfActiveSessions.hasNext) { + val session = iterOfActiveSessions.next() + newParentSessionStore.deltaMap.update( + session.getKey(), SessionInfo(session.getData(), deleted = false)) + } + parentSessionStore = newParentSessionStore + } + outputStream.defaultWriteObject() + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + inputStream.defaultReadObject() + } +*/ } @@ -385,7 +440,7 @@ case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: private[streaming] object OpenHashMapBasedSessionStore { - val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + val DELTA_CHAIN_LENGTH_THRESHOLD = 10 } From 51465f4bea5da3afe6067f05ce2a1bc965b277b5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Oct 2015 18:52:06 -0700 Subject: [PATCH 10/26] Updated API based on updated design --- .../dstream/PairDStreamFunctions.scala | 9 +- .../streaming/dstream/SessionDStream.scala | 665 ++++++++++-------- .../streaming/BasicOperationsSuite.scala | 45 +- .../spark/streaming/SessionStoreSuite.scala | 129 ---- .../spark/streaming/StateMapSuite.scala | 101 +++ 5 files changed, 495 insertions(+), 454 deletions(-) delete mode 100644 
streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index c29593465cf7..29b42b4e35fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -349,9 +349,12 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } - def sessionByKey[S: ClassTag](sessionSpec: SessionSpec[K, V, S]): DStream[Session[K, S]] = { - new SessionDStream[K, V, S](self, sessionSpec).mapPartitions { partitionIter => - partitionIter.flatMap { _.iterator(!sessionSpec.getAllSessions()) } + def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = { + new TrackStateDStream[K, V, S, T]( + self, + spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] + ).mapPartitions { partitionIter => + partitionIter.flatMap { _.emittedRecords } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala index 6ad966eff409..32c7e69323c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala @@ -2,15 +2,17 @@ package org.apache.spark.streaming.dstream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.streaming.dstream.OpenHashMapBasedStateMap._ // ================================================== @@ -19,275 +21,299 @@ import org.apache.spark.util.collection.OpenHashMap // ================================================== // ================================================== +sealed abstract class State[S] { + def isDefined(): Boolean + def get(): S + def update(newState: S): Unit + def remove(): Unit + def isTimingOut(): Boolean -/** Represents a session */ -case class Session[K, S] private[streaming]( - private var key: K, private var data: S, private var active: Boolean) { + @inline final def getOrElse[S1 >: S](default: => S1): S1 = + if (isDefined) default else this.get +} - def this() = this(null.asInstanceOf[K], null.asInstanceOf[S], true) - private[streaming] def set(k: K, s: S, a: Boolean): this.type = { - key = k - data = s - active = a - this - } +/** Class representing all the specification of session */ +abstract class TrackStateSpec[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag] + extends Serializable { - /** Get the session key */ - def getKey(): K = key + def initialState(rdd: RDD[(K, S)]): this.type + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type - /** Get the session value */ - def getData(): S = data + def numPartitions(numPartitions: Int): this.type + def partitioner(partitioner: 
Partitioner): this.type - /** Whether the session is active */ - def isActive(): Boolean = active -/* - override def toString(): String = { - s"Session[ Key=$key, Data=$session, active=$active ]" - }*/ + def timeout(interval: Duration): this.type } -private[streaming] object Session { +object TrackStateSpec { + def apply[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + new TrackStateSpecImpl[K, V, S, T](trackingFunction) + } + + def create[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + apply(trackingFunction) + } } + +// =============================================== +// =============================================== +// ============== PRIVATE CLASSES ================ +// =============================================== +// =============================================== + + /** Class representing all the specification of session */ -class SessionSpec[K: ClassTag, V: ClassTag, S: ClassTag] private[streaming]( - updateFunction: (V, Option[S]) => Option[S]) extends Serializable { +private[streaming] case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + function: (K, Option[V], State[S]) => Option[T]) extends TrackStateSpec[K, V, S, T] { + + require(function != null) + @volatile private var partitioner: Partitioner = null - @volatile private var initialSessionRDD: RDD[(K, S)] = null - @volatile private var allSessions: Boolean = false + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null - def setPartition(partitioner: Partitioner): this.type = { - this.partitioner = partitioner + + def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd this } - def setInitialSessions(initialRDD: RDD[(K, S)]): this.type = { - this.initialSessionRDD = initialRDD + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd this } - def reportAllSession(allSessions: Boolean): this.type = { - this.allSessions = allSessions + + def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner this } + def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) - private[streaming] def getUpdateFunction(): (V, Option[S]) => Option[S] = updateFunction + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} + - private[streaming] def getInitialSessions(): Option[RDD[(K, S)]] = Option(initialSessionRDD) +private[streaming] class StateImpl[S] extends State[S] { - private[streaming] def getAllSessions(): Boolean = allSessions + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = true + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false - private[streaming] def validate(): Unit = { - require(updateFunction != null) + // ========= Public API 
========= + def isDefined(): Boolean = { + defined } -} -object SessionSpec { - def create[K: ClassTag, V: ClassTag, S: ClassTag]( - updateFunction: (V, Option[S]) => Option[S]): SessionSpec[K, V, S] = { - new SessionSpec[K, V, S](updateFunction) + def get(): S = { + null.asInstanceOf[S] } -} + def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + updated = true + state = newState + } + def isTimingOut(): Boolean = { + timingOut + } -// =============================================== -// =============================================== -// ============== PRIVATE CLASSES ================ -// =============================================== -// =============================================== + def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + removed = true + } + + // ========= Internal API ========= + + def isRemoved(): Boolean = { + removed + } + + def isUpdated(): Boolean = { + updated + } + + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + def wrapTiminoutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } + + +} // ----------------------------------------------- -// --------------- SessionStore stuff -------------- +// --------------- StateMap stuff -------------- // ----------------------------------------------- -/** - * Internal interface for defining the map that keeps track of sessions. - */ -private[streaming] abstract class SessionStore[K: ClassTag, S: ClassTag] extends Serializable { - /** Add or update session data */ +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { - def put(key: K, session: S): Unit - - /** Get the session data if it exists */ + /** Get the state for a key if it exists */ def get(key: K): Option[S] + /** Get all the keys and states whose updated time is older than the give threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + /** Remove a key */ def remove(key: K): Unit /** - * Shallow copy the map to create a new session store. Updates to the new map - * should not mutate `this` map. + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. */ - def copy(): SessionStore[K, S] + def copy(): StateMap[K, S] - /** - * Return an iterator of data in this map. If th flag is true, implementations should - * return only the session that were updated since the creation of this map. 
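 *
 * For the new API introduced in this patch, a running word count gives a minimal usage
 * sketch of the `(K, Option[V], State[S]) => Option[T]` tracking function (the input
 * stream `wordStream: DStream[(String, Int)]` is assumed here, e.g. built with
 * `words.map(w => (w, 1))`):
 * {{{
 *   val trackingFunc = (word: String, count: Option[Int], state: State[Int]) => {
 *     val sum = count.getOrElse(0) + state.getOrElse(0)   // previous total, 0 if absent
 *     state.update(sum)                                   // write the new state back
 *     Some((word, sum))                                   // record emitted for this batch
 *   }
 *   val emitted = wordStream.trackStateByKey(TrackStateSpec.create(trackingFunc))
 *   // emitted: DStream[(String, Int)] of per-word running totals
 * }}}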
- */ - def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] + def toDebugString(): String = toString() } -private[streaming] object SessionStore { - def empty[K: ClassTag, S: ClassTag]: SessionStore[K, S] = new EmptySessionStore[K, S] +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] - def create[K: ClassTag, S: ClassTag](conf: SparkConf): SessionStore[K, S] = { + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", - OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedSessionStore[K, S](64, deltaChainThreshold) + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) } } /** Specific implementation of SessionStore interface representing an empty map */ -private[streaming] class EmptySessionStore[K: ClassTag, S: ClassTag] extends SessionStore[K, S] { - override def put(key: K, session: S): Unit = ??? +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = ??? override def get(key: K): Option[S] = None - override def copy(): SessionStore[K, S] = new EmptySessionStore[K, S] + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] override def remove(key: K): Unit = { } - override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def toDebugString(): String = "" } +/** Implementation of StateMap based on Spark's OpenHashMap */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile private var parentStateMap: StateMap[K, S], + initialCapacity: Int = 64, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => -/** Specific implementation of the SessionMap interface using a scala mutable HashMap */ -private[streaming] class HashMapBasedSessionStore[K: ClassTag, S: ClassTag]( - parentSessionStore: SessionStore[K, S]) extends SessionStore[K, S] { - - def this() = this(new EmptySessionStore[K, S]) + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) - private val deltaChainLength: Int = parentSessionStore match { - case map: HashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 - case _ => 0 - } + def this(deltaChainThreshold: Int) = this( + initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) - private val internalMap = new mutable.HashMap[K, SessionInfo[S]] + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - override def put(key: K, session: S): Unit = { - internalMap.get(key) match { - case Some(sessionInfo) => - sessionInfo.data = session - case None => - internalMap.put(key, new SessionInfo(session)) - } - } + @transient @volatile private var deltaMap = + new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { - internalMap.get(key).filter { _.deleted == false }.map { _.data }.orElse(parentSessionStore.get(key)) - } - - /** Remove a key */ - override def remove(key: K): Unit = { - internalMap.put(key, new 
SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) - } - - /** - * Return an iterator of data in this map. If th flag is true, implementations should - * return only the session that were updated since the creation of this map. - */ - override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = { - val updatedSessions = internalMap.iterator.map { case (key, sessionInfo) => - Session(key, sessionInfo.data, !sessionInfo.deleted) - } - - def previousSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { session => - !internalMap.contains(session.getKey()) - } - - if (updatedSessionsOnly) { - updatedSessions + val stateInfo = deltaMap(key) + if (stateInfo != null && !stateInfo.deleted) { + Some(stateInfo.data) } else { - previousSessions ++ updatedSessions + parentStateMap.get(key) } } - /** - * Shallow copy the map to create a new session store. Updates to the new map - * should not mutate `this` map. - */ - override def copy(): SessionStore[K, S] = { - doCopy(deltaChainLength >= HashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) - } + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } - def doCopy(consolidate: Boolean): SessionStore[K, S] = { - if (consolidate) { - val newParentMap = new HashMapBasedSessionStore[K, S]() - iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => - newParentMap.internalMap.put(session.getKey(), SessionInfo(session.getData(), deleted = false)) - } - new HashMapBasedSessionStore[K, S](newParentMap) - } else { - new HashMapBasedSessionStore[K, S](this) + val updatedStates = deltaMap.iterator.flatMap { case (key, stateInfo) => + if (! stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime) { + Some((key, stateInfo.data, stateInfo.updateTime)) + } else None } + oldStates ++ updatedStates } -} - -private[streaming] object HashMapBasedSessionStore { - val DELTA_CHAIN_LENGTH_THRESHOLD = 10 -} - - -private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( - @transient @volatile private var parentSessionStore: SessionStore[K, S], - initialCapacity: Int = 64, - deltaChainThreshold: Int = OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD - ) extends SessionStore[K, S] { - - def this(initialCapacity: Int, deltaChainThreshold: Int) = - this(new EmptySessionStore[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this(initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { - def this() = this(OpenHashMapBasedSessionStore.DELTA_CHAIN_LENGTH_THRESHOLD) - - @transient @volatile private var deltaMap = new OpenHashMap[K, SessionInfo[S]](initialCapacity) - - override def put(key: K, session: S): Unit = { - val sessionInfo = deltaMap(key) - if (sessionInfo != null) { - sessionInfo.data = session - } else { - deltaMap.update(key, new SessionInfo(session)) + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) } + + val updatedStates = deltaMap.iterator.filter { ! 
_._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates } - /** Get the session data if it exists */ - override def get(key: K): Option[S] = { - val sessionInfo = deltaMap(key) - if (sessionInfo != null && sessionInfo.deleted == false) { - Some(sessionInfo.data) + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) } else { - parentSessionStore.get(key) + deltaMap.update(key, new StateInfo(state, updateTime)) } } - /** Remove a key */ + /** Remove a state */ override def remove(key: K): Unit = { - deltaMap.update(key, new SessionInfo(get(key).getOrElse(null.asInstanceOf[S]), deleted = true)) - } - - /** - * Return an iterator of data in this map. If th flag is true, implementations should - * return only the session that were updated since the creation of this map. - */ - override def iterator(updatedSessionsOnly: Boolean): Iterator[Session[K, S]] = { - val updatedSessions = deltaMap.iterator.map { case (key, sessionInfo) => - Session(key, sessionInfo.data, !sessionInfo.deleted) - } - - def previousSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { session => - !deltaMap.contains(session.getKey()) - } - - if (updatedSessionsOnly) { - updatedSessions + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() } else { - previousSessions ++ updatedSessions + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) } } @@ -295,39 +321,54 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( * Shallow copy the map to create a new session store. Updates to the new map * should not mutate `this` map. 
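 *
 * The custom serialization further below streams an unknown number of parent entries and
 * terminates them with a `Limiter(count)` sentinel instead of writing a length up front.
 * A standalone, simplified sketch of that framing (the entry type and helper names are
 * illustrative only; `Limiter` is the class defined in this file):
 * {{{
 *   import java.io.{ObjectInputStream, ObjectOutputStream}
 *
 *   def writeEntries(out: ObjectOutputStream, entries: Iterator[(String, Int)]): Unit = {
 *     var n = 0
 *     entries.foreach { e => out.writeObject(e); n += 1 }
 *     out.writeObject(new Limiter(n))              // sentinel also carries the final count
 *   }
 *
 *   def readEntries(in: ObjectInputStream): Seq[(String, Int)] = {
 *     val buf = scala.collection.mutable.ArrayBuffer.empty[(String, Int)]
 *     var done = false
 *     while (!done) {
 *       in.readObject() match {
 *         case l: Limiter => assert(l.num == buf.size); done = true
 *         case e          => buf += e.asInstanceOf[(String, Int)]
 *       }
 *     }
 *     buf.toSeq
 *   }
 * }}}
 * Writing the exact count only at the end lets the writer consolidate the delta chain in the
 * same single pass over the parent entries, which is presumably why that count is not known
 * before the loop starts.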
*/ - override def copy(): SessionStore[K, S] = { - new OpenHashMapBasedSessionStore[K, S](this, deltaChainThreshold = deltaChainThreshold) + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) } -/* - private[streaming] def doCopy(consolidate: Boolean): SessionStore[K, S] = { - if (consolidate) { - val newParentMap = new OpenHashMapBasedSessionStore[K, S]( - initialCapacity = sizeHint, deltaChainThreshold) - iterator(updatedSessionsOnly = false).filter { _.isActive }.foreach { case session => - newParentMap.internalMap.update(session.getKey(), SessionInfo(session.getData(), deleted = false)) - } - new OpenHashMapBasedSessionStore[K, S](newParentMap, deltaChainThreshold = deltaChainThreshold) - } else { - new OpenHashMapBasedSessionStore[K, S](this, deltaChainThreshold = deltaChainThreshold) - } + + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold } -*/ - private def deltaChainLength: Int = parentSessionStore match { - case map: OpenHashMapBasedSessionStore[_, _] => map.deltaChainLength + 1 + + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 case _ => 0 } - private def sizeHint(): Int = deltaMap.size + { - parentSessionStore match { - case s: OpenHashMapBasedSessionStore[_, _] => s.sizeHint() + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize case _ => 0 } } + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) +"+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + /* + class CompactParentOnCompletionIterator(iterator: Iterator[(K, S, Long)]) + extends CompletionIterator[(K, S, Long), Iterator[(K, S, Long)]](iterator) { + val newParentStateMap = + new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) + + override def next(): (K, S, Long) = { + val next = super.next() + newParentStateMap.put(next._1, next._2, next._3) + next + } + + override def completion(): Unit = { + self.parentStateMap = newParentStateMap + } + } + */ private def writeObject(outputStream: ObjectOutputStream): Unit = { + outputStream.defaultWriteObject() // Write the deltaMap @@ -336,40 +377,41 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( var deltaMapCount = 0 while (deltaMapIterator.hasNext) { deltaMapCount += 1 - val keyedSessionInfo = deltaMapIterator.next() - outputStream.writeObject(keyedSessionInfo._1) - outputStream.writeObject(keyedSessionInfo._2) + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) } assert(deltaMapCount == deltaMap.size) - // Write the parentSessionStore while consolidating - val consolidate = deltaChainLength > deltaChainThreshold - val newParentSessionStore = if (consolidate) { - new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) + // Write the parentStateMap while consolidating + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) } else { null } - val iterOfActiveSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { _.isActive } + val iterOfActiveSessions = parentStateMap.getAll() var parentSessionCount = 0 - 
outputStream.writeInt(sizeHint) + outputStream.writeInt(approxSize) while(iterOfActiveSessions.hasNext) { parentSessionCount += 1 - val session = iterOfActiveSessions.next() - outputStream.writeObject(session.getKey()) - outputStream.writeObject(session.getData()) + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) - if (consolidate) { + if (doCompaction) { newParentSessionStore.deltaMap.update( - session.getKey(), SessionInfo(session.getData(), deleted = false)) + key, StateInfo(state, updateTime, deleted = false)) } } val limiterObj = new Limiter(parentSessionCount) outputStream.writeObject(limiterObj) - if (consolidate) { - parentSessionStore = newParentSessionStore + if (doCompaction) { + parentStateMap = newParentSessionStore } } @@ -377,23 +419,22 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( inputStream.defaultReadObject() val deltaMapSize = inputStream.readInt() - deltaMap = new OpenHashMap[K, SessionInfo[S]]() + deltaMap = new OpenHashMap[K, StateInfo[S]]() var deltaMapCount = 0 while (deltaMapCount < deltaMapSize) { val key = inputStream.readObject().asInstanceOf[K] - val sessionInfo = inputStream.readObject().asInstanceOf[SessionInfo[S]] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] deltaMap.update(key, sessionInfo) deltaMapCount += 1 } val parentSessionStoreSizeHint = inputStream.readInt() - val newParentSessionStore = new OpenHashMapBasedSessionStore[K, S]( + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) var parentSessionLoopDone = false while(!parentSessionLoopDone) { val obj = inputStream.readObject() - //println("Read: " + obj) if (obj.isInstanceOf[Limiter]) { parentSessionLoopDone = true val expectedCount = obj.asInstanceOf[Limiter].num @@ -401,11 +442,12 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( } else { val key = obj.asInstanceOf[K] val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() newParentSessionStore.deltaMap.update( - key, SessionInfo(state, deleted = false)) + key, StateInfo(state, updateTime, deleted = false)) } } - parentSessionStore = newParentSessionStore + parentStateMap = newParentSessionStore } /* @@ -413,7 +455,7 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( if (deltaChainLength > deltaChainThreshold) { val newParentSessionStore = new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) - val iterOfActiveSessions = parentSessionStore.iterator(updatedSessionsOnly = false).filter { + val iterOfActiveSessions = parentStateMap.iterator(updatedSessionsOnly = false).filter { _.isActive } @@ -422,7 +464,7 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( newParentSessionStore.deltaMap.update( session.getKey(), SessionInfo(session.getData(), deleted = false)) } - parentSessionStore = newParentSessionStore + parentStateMap = newParentSessionStore } outputStream.defaultWriteObject() } @@ -431,27 +473,45 @@ private[streaming] class OpenHashMapBasedSessionStore[K: ClassTag, S: ClassTag]( inputStream.defaultReadObject() } */ - } class Limiter(val num: Int) extends Serializable -case class SessionInfo[SessionDataType](var data: SessionDataType, var deleted: Boolean = false) -private[streaming] object OpenHashMapBasedSessionStore 
{ +private[streaming] object OpenHashMapBasedStateMap { + + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var deleted: Boolean = false) { + + def markDeleted(): Unit = { + deleted = true + } - val DELTA_CHAIN_LENGTH_THRESHOLD = 10 + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 } // ----------------------------------------------- -// --------------- SessionRDD stuff -------------- +// --------------- StateRDD stuff -------------- // ----------------------------------------------- -private[streaming] class SessionRDDPartition( +private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( + stateMap: StateMap[K, S], emittedRecords: Seq[T]) + + +private[streaming] class TrackStateRDDPartition( idx: Int, - @transient private var previousSessionRDD: RDD[_], + @transient private var prevStateRDD: RDD[_], @transient private var partitionedDataRDD: RDD[_]) extends Partition { private[dstream] var previousSessionRDDPartition: Partition = null @@ -463,72 +523,103 @@ private[streaming] class SessionRDDPartition( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent split at the time of task serialization - previousSessionRDDPartition = previousSessionRDD.partitions(index) + previousSessionRDDPartition = prevStateRDD.partitions(index) partitionedDataRDDPartition = partitionedDataRDD.partitions(index) oos.defaultWriteObject() } } -private[streaming] class SessionRDD[K: ClassTag, V: ClassTag, S: ClassTag]( +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( _sc: SparkContext, - private var previousSessionRDD: RDD[SessionStore[K, S]], + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], private var partitionedDataRDD: RDD[(K, V)], - updateFunction: (V, Option[S]) => Option[S], - timestamp: Long - ) extends RDD[SessionStore[K, S]]( + trackingFunction: (K, Option[V], State[S]) => Option[T], + currentTime: Long, timeoutThresholdTime: Option[Long] + ) extends RDD[TrackStateRDDRecord[K, S, T]]( _sc, - List(new OneToOneDependency(previousSessionRDD), new OneToOneDependency(partitionedDataRDD)) + List( + new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) ) { - require(partitionedDataRDD.partitioner == previousSessionRDD.partitioner) + @volatile private var doFullScan = false + + require(partitionedDataRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } - override val partitioner = previousSessionRDD.partitioner + override def compute( + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { - override def compute(partition: Partition, context: TaskContext): Iterator[SessionStore[K, S]] = { - val sessionRDDPartition = partition.asInstanceOf[SessionRDDPartition] - val prevSessionIterator = previousSessionRDD.iterator( - sessionRDDPartition.previousSessionRDDPartition, context) + val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) val dataIterator = 
partitionedDataRDD.iterator( - sessionRDDPartition.partitionedDataRDDPartition, context) + stateRDDPartition.partitionedDataRDDPartition, context) + if (!prevStateRDDIterator.hasNext) { + throw new SparkException(s"Could not find state map in previous state RDD") + } + + val newStateMap = prevStateRDDIterator.next().stateMap.copy() + val emittedRecords = new ArrayBuffer[T] - require(prevSessionIterator.hasNext) + val stateWrapper = new StateImpl[S]() - val sessionStore = prevSessionIterator.next().copy() dataIterator.foreach { case (key, value) => - val prevState = sessionStore.get(key) - val newState = updateFunction(value, prevState) - if (newState.isDefined) { - sessionStore.put(key, newState.get) - } else { - sessionStore.remove(key) + stateWrapper.wrap(newStateMap.get(key)) + val emittedRecord = trackingFunction(key, Some(value), stateWrapper) + if (stateWrapper.isRemoved()) { + newStateMap.remove(key) + } else if (stateWrapper.isUpdated()) { + newStateMap.put(key, stateWrapper.get(), currentTime) } + emittedRecords ++= emittedRecord } - Iterator(sessionStore) + + if (doFullScan) { + if (timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + stateWrapper.wrapTiminoutState(state) + val emittedRecord = trackingFunction(key, None, stateWrapper) + emittedRecords ++= emittedRecord + } + } + } + + Iterator(TrackStateRDDRecord(newStateMap, emittedRecords)) } override protected def getPartitions: Array[Partition] = { - Array.tabulate(previousSessionRDD.partitions.length) { i => - new SessionRDDPartition(i, previousSessionRDD, partitionedDataRDD)} + Array.tabulate(prevStateRDD.partitions.length) { i => + new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} } override def clearDependencies() { super.clearDependencies() - previousSessionRDD = null + prevStateRDD = null partitionedDataRDD = null } } -private[streaming] object SessionRDD { - def createFromPairRDD[K: ClassTag, S: ClassTag]( - pairRDD: RDD[(K, S)], partitioner: Partitioner): RDD[SessionStore[K, S]] = { +private[streaming] object TrackStateRDD { + def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = { - val createStateMap = (iterator: Iterator[(K, S)]) => { - val newSessionStore = SessionStore.create[K, S](SparkEnv.get.conf) - iterator.foreach { case (key, state) => newSessionStore.put(key, state) } - Iterator(newSessionStore) + val createRecord = (iterator: Iterator[(K, S)]) => { + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) } - pairRDD.partitionBy(partitioner).mapPartitions[SessionStore[K, S]]( - createStateMap, preservesPartitioning = true) + pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]]( + createRecord, true) } } @@ -538,17 +629,16 @@ private[streaming] object SessionRDD { // ----------------------------------------------- -private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( - parent: DStream[(K, V)], sessionSpec: SessionSpec[K, V, S]) - extends DStream[SessionStore[K, S]](parent.context) { +private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) + extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { - 
sessionSpec.validate() persist(StorageLevel.MEMORY_ONLY) - private val partitioner = sessionSpec.getPartitioner().getOrElse( + private val partitioner = spec.getPartitioner().getOrElse( new HashPartitioner(ssc.sc.defaultParallelism)) - private val updateFunction = sessionSpec.getUpdateFunction() + private val trackingFunction = spec.getFunction() override def slideDuration: Duration = parent.slideDuration @@ -557,17 +647,22 @@ private[streaming] class SessionDStream[K: ClassTag, V: ClassTag, S: ClassTag]( override val mustCheckpoint = true /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[SessionStore[K, S]]] = { - val previousSessionRDD = getOrCompute(validTime - slideDuration).getOrElse { - SessionRDD.createFromPairRDD[K, S]( - sessionSpec.getInitialSessions().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, S, T]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime.milliseconds ) } val newDataRDD = parent.getOrCompute(validTime).get val partitionedDataRDD = newDataRDD.partitionBy(partitioner) - Some(new SessionRDD( - ssc.sparkContext, previousSessionRDD, partitionedDataRDD, - updateFunction, validTime.milliseconds)) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + + Some(new TrackStateRDD( + ssc.sparkContext, prevStateRDD, partitionedDataRDD, + trackingFunction, validTime.milliseconds, timeoutThresholdTime)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 24a834e4dbac..75ec238dd926 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.{Session, SessionSpec, DStream, WindowedDStream} +import org.apache.spark.streaming.dstream.{State, TrackStateSpec, DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.{HashPartitioner, SparkConf, SparkException} @@ -630,7 +630,7 @@ class BasicOperationsSuite extends TestSuiteBase { } - test("sessionByKey") { + test("trackStateByKey with emitted states") { val inputData = Seq( Seq("a"), @@ -649,45 +649,16 @@ class BasicOperationsSuite extends TestSuiteBase { Seq(("a", 4), ("b", 3)), Seq(("a", 5)), Seq() - ).map { _.map { case (key, value) => Session(key, value, true) } } - - val sessionOperation = (s: DStream[String]) => { - val updateFunc = (value: Int, sessionData: Option[Int]) => { - Option(value + sessionData.getOrElse(0)) - } - s.map(x => (x, 1)).sessionByKey(SessionSpec.create[String, Int, Int](updateFunc)) - } - - testOperation(inputData, sessionOperation, outputData, true) - } - - test("sessionByKey - with all sessions") { - val inputData = - Seq( - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() ) - val outputData = - Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 
5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ).map { _.map { case (key, value) => Session(key, value, true) } } - val sessionOperation = (s: DStream[String]) => { - val updateFunc = (value: Int, sessionData: Option[Int]) => { - Option(value + sessionData.getOrElse(0)) + val updateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) } - s.map(x => (x, 1)).sessionByKey( - SessionSpec.create[String, Int, Int](updateFunc).reportAllSession(true)) + s.map(x => (x, 1)).trackStateByKey(TrackStateSpec.create(updateFunc)) } testOperation(inputData, sessionOperation, outputData, true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala deleted file mode 100644 index c3a615c4ea88..000000000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/SessionStoreSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -package org.apache.spark.streaming - -import scala.reflect._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.dstream.{HashMapBasedSessionStore, OpenHashMapBasedSessionStore, Session, SessionStore} -import org.apache.spark.util.Utils - -class SessionStoreSuite extends SparkFunSuite { - - HashMapBasedSessionStoreTester.testStore() - OpenHashMapBasedSessionStoreTester.testStore() - - abstract class SessionStoreTester[StoreType <: SessionStore[Int, Int]: ClassTag] { - - private val clazz = classTag[StoreType].runtimeClass - private val className = clazz.getSimpleName - - protected def newStore(): StoreType - - def testStore(): Unit = { - - test(className + "- put, get, remove, iterator") { - val map = newStore() - - map.put(1, 100) - assert(map.get(1) === Some(100)) - assert(map.get(2) === None) - map.put(2, 200) - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, true), Session(2, 200, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, true), Session(2, 200, true))) - - map.remove(1) - assert(map.get(1) === None) - - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - } - - test(className + " - put, get, remove, iterator after copy") { - val parentMap = newStore() - parentMap.put(1, 100) - parentMap.put(2, 200) - parentMap.remove(1) - - val map = parentMap.copy() - assert(map.iterator(updatedSessionsOnly = true).toSet === Set()) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - - map.put(3, 300) - map.put(4, 400) - map.remove(4) - - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(3, 300, true), Session(4, 400, false))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true), - Session(3, 300, true), Session(4, 400, false))) - - assert(parentMap.iterator(updatedSessionsOnly = true).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - assert(parentMap.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 100, false), Session(2, 200, true))) - - map.put(1, 1000) - map.put(2, 2000) - assert(map.iterator(updatedSessionsOnly = true).toSet === - Set(Session(3, 300, true), Session(4, 400, false), - Session(1, 1000, 
true), Session(2, 2000, true))) - assert(map.iterator(updatedSessionsOnly = false).toSet === - Set(Session(1, 1000, true), Session(2, 2000, true), - Session(3, 300, true), Session(4, 400, false))) - } - /* - test(className + " - copying with consolidation") { - val map1 = newStore() - map1.put(1, 100) - map1.put(2, 200) - - val map2 = map1.copy() - map2.put(3, 300) - map2.put(4, 400) - - val map3 = map2.copy() - map3.put(3, 600) - map3.put(4, 700) - - assert(map3.iterator(false).toSet === - map3.asInstanceOf[HashMapBasedSessionStore[Int, Int]].doCopy(true).iterator(false).toSet) - - }*/ - - test(className + " - serializing and deserializing") { - val map1 = newStore() - map1.put(1, 100) - map1.put(2, 200) - - val map2 = map1.copy() - map2.put(3, 300) - map2.put(4, 400) - - val map3 = map2.copy() - map3.put(3, 600) - map3.remove(2) - - val map3_ = Utils.deserialize[SessionStore[Int, Int]](Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assert(map3_.iterator(true).toSet === map3.iterator(true).toSet) - assert(map3_.iterator(false).toSet === map3.iterator(false).toSet) - } - } - } - - object HashMapBasedSessionStoreTester extends SessionStoreTester[HashMapBasedSessionStore[Int, Int]] { - override protected def newStore(): HashMapBasedSessionStore[Int, Int] = { - new HashMapBasedSessionStore[Int, Int]() - } - } - - object OpenHashMapBasedSessionStoreTester extends SessionStoreTester[OpenHashMapBasedSessionStore[Int, Int]] { - override protected def newStore(): OpenHashMapBasedSessionStore[Int, Int] = { - new OpenHashMapBasedSessionStore[Int, Int]() - } - } -} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 000000000000..604d031c93f7 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,101 @@ +package org.apache.spark.streaming + +import scala.reflect._ +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.dstream.{StateMap, OpenHashMapBasedStateMap} +import org.apache.spark.util.Utils + +class StateMapSuite extends SparkFunSuite { + + test("OpenHashMapBasedStateMap - basic operations") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + map.put(2, 200, 20) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - basic operations after copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + val map = parentMap.copy() + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add new items + map.put(3, 300, 3) + map.put(4, 400, 4) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) // remove item added to this map + map.remove(2) // remove item remove in parent map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) // update item removed in parent map + map.put(2, 2000, 200) // update item added in parent map and removed in this map + map.put(3, 3000, 300) // update item added 
in this map + map.put(4, 4000, 400) // update item removed in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + map1.put(1, 100, 1) + map1.put(2, 200, 2) + + val map2 = map1.copy() + map2.put(3, 300, 3) + map2.put(4, 400, 4) + + val map3 = map2.copy() + map3.put(3, 600, 3) + map3.remove(2) + + // Do not test compaction + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + + val map3_ = Utils.deserialize[StateMap[Int, Int]](Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + assert(map3_.getAll().toSet === map3.getAll().toSet) + assert(map3.getAll().forall { case (key, state, _) => map3_.get(key) === Some(state)}) + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), Random.nextLong()) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + assert(deser_map.getAll().toSet === map.getAll().toSet) + assert(map.getAll().forall { case (key, state, _) => deser_map.get(key) === Some(state)}) + + } +} \ No newline at end of file From 10f6a0ecbb56e8aaad50681398bfbda7e1134f92 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 22 Oct 2015 19:41:42 -0700 Subject: [PATCH 11/26] Refactoring the code --- .../streaming/StatefulNetworkWordCount.scala | 15 +- .../org/apache/spark/streaming/State.scala | 116 +++ .../spark/streaming/TrackStateSpec.scala | 94 +++ .../dstream/PairDStreamFunctions.scala | 6 +- .../streaming/dstream/SessionDStream.scala | 668 ------------------ .../dstream/TrackedStateDStream.scala | 217 ++++++ .../spark/streaming/util/StateMap.scala | 276 ++++++++ .../streaming/BasicOperationsSuite.scala | 2 +- 8 files changed, 711 insertions(+), 683 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/State.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index cd653f57f5af..02ba1c2eed0f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf +import 
org.apache.spark.HashPartitioner import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.SessionSpec /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -44,7 +44,6 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - /* val updateFunc = (values: Seq[Int], state: Option[Int]) => { val currentCount = values.sum @@ -56,11 +55,10 @@ object StatefulNetworkWordCount { val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) } - */ val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size - val ssc = new StreamingContext(sparkConf, Milliseconds(2000)) + val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") // Initial RDD input to updateStateByKey @@ -74,13 +72,8 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - - val updateFunc = (value: Int, sessionData: Option[Int]) => { - Option(value + sessionData.getOrElse(0)) - } - - val stateDstream = wordDstream.sessionByKey[Int]( - SessionSpec.create(updateFunc).reportAllSession(true)) + val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, + new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 000000000000..0f660bd94060 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,116 @@ +package org.apache.spark.streaming + +/** + * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] and + * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. + * {{{ + * + * }}} + */ +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise wise it will throw an exception. + * Check with `exists()` whether the state exists or not before calling `get()`. + */ + def get(): S + + /** + * Update the state with a new value. Note that you cannot update the state if the state is + * timing out (that is, `isTimingOut() return true`, or if the state has already been removed by + * `remove()`. + */ + def update(newState: S): Unit + + /** Remove the state if it exists. 
*/ + def remove(): Unit + + /** Is the state going to be timed out by the system after this batch interval */ + def isTimingOut(): Boolean + + /** Get the state if it exists, otherwise return the default value */ + @inline final def getOrElse[S1 >: S](default: => S1): S1 = + if (exists) default else this.get +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = true + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + def exists(): Boolean = { + defined + } + + def get(): S = { + null.asInstanceOf[S] + } + + def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + updated = true + state = newState + } + + def isTimingOut(): Boolean = { + timingOut + } + + def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Internal method to update the state data and reset internal flags in `this`. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Internal method to update the state data and reset internal flags in `this`. + * This method allows `this` object to be reused across many state records. + */ + def wrapTiminoutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala new file mode 100644 index 000000000000..86f6b669367c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala @@ -0,0 +1,94 @@ +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.rdd.RDD + + +/** + * Abstract class having all the specifications of DStream.trackStateByKey(). + * Use the `TrackStateSpec.create()` or `TrackStateSpec.create()` to create instances of this class. 
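+ *
+ * As a fuller illustration (an editor's sketch only: the `wordCounts` stream, the partition
+ * count and the timeout value below are assumed, not defined by this patch), a running
+ * word count could be expressed as:
+ * {{{
+ *    // tracking function of type (String, Option[Int], State[Int]) => Option[(String, Int)]
+ *    val trackingFunction = (word: String, count: Option[Int], state: State[Int]) => {
+ *      val sum = count.getOrElse(0) + state.getOrElse(0)
+ *      state.update(sum)          // carry the running sum forward as the tracked state
+ *      Some((word, sum))          // record emitted into the resulting DStream
+ *    }
+ *
+ *    val spec = TrackStateSpec.create(trackingFunction)
+ *      .numPartitions(10)
+ *      .timeout(Minutes(30))
+ *
+ *    // wordCounts: DStream[(String, Int)]; tracked: DStream[(String, Int)]
+ *    val tracked = wordCounts.trackStateByKey(spec)
+ * }}}
+ *
+ * The minimal forms of construction are: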
+ * + * {{{ + * TrackStateSpec(trackingFunction) // in Scala + * TrackStateSpec.create(trackingFunction) // in Java + * }}} + */ +sealed abstract class TrackStateSpec[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag] + extends Serializable { + + def initialState(rdd: RDD[(K, S)]): this.type + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type + + def numPartitions(numPartitions: Int): this.type + def partitioner(partitioner: Partitioner): this.type + + def timeout(interval: Duration): this.type +} + + +/** Builder object for creating instances of TrackStateSpec */ +object TrackStateSpec { + + def apply[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + new TrackStateSpecImpl[K, V, S, T](trackingFunction) + } + + def create[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { + apply(trackingFunction) + } +} + + +/** Internal implementation of [[TrackStateSpec]] interface */ +private[streaming] +case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + function: (K, Option[V], State[S]) => Option[T]) extends TrackStateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + + def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + + def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 29b42b4e35fc..77245923426b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,11 +24,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.streaming.{Duration, Time, TrackStateSpec, TrackStateSpecImpl} import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} +import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of 
(key, value) pairs through an implicit conversion. @@ -350,7 +350,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) } def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = { - new TrackStateDStream[K, V, S, T]( + new TrackedStateDStream[K, V, S, T]( self, spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] ).mapPartitions { partitionIter => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala deleted file mode 100644 index 32c7e69323c0..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SessionDStream.scala +++ /dev/null @@ -1,668 +0,0 @@ -package org.apache.spark.streaming.dstream - -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - -import org.apache.spark._ -import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.rdd.{EmptyRDD, RDD} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.util.{CompletionIterator, Utils} -import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.streaming.dstream.OpenHashMapBasedStateMap._ - - -// ================================================== -// ================================================== -// ================= PUBLIC CLASSES ================= -// ================================================== -// ================================================== - -sealed abstract class State[S] { - def isDefined(): Boolean - def get(): S - def update(newState: S): Unit - def remove(): Unit - def isTimingOut(): Boolean - - @inline final def getOrElse[S1 >: S](default: => S1): S1 = - if (isDefined) default else this.get -} - - -/** Class representing all the specification of session */ -abstract class TrackStateSpec[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag] - extends Serializable { - - def initialState(rdd: RDD[(K, S)]): this.type - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type - - def numPartitions(numPartitions: Int): this.type - def partitioner(partitioner: Partitioner): this.type - - def timeout(interval: Duration): this.type -} - -object TrackStateSpec { - - def apply[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { - new TrackStateSpecImpl[K, V, S, T](trackingFunction) - } - - def create[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { - apply(trackingFunction) - } -} - - -// =============================================== -// =============================================== -// ============== PRIVATE CLASSES ================ -// =============================================== -// =============================================== - - -/** Class representing all the specification of session */ -private[streaming] case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - function: (K, Option[V], State[S]) => Option[T]) extends TrackStateSpec[K, V, S, T] { - - require(function != null) - - @volatile private var partitioner: Partitioner = null - @volatile private var initialStateRDD: RDD[(K, S)] = null - @volatile private var timeoutInterval: Duration = null - - - def initialState(rdd: RDD[(K, S)]): this.type = { - 
this.initialStateRDD = rdd - this - } - - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { - this.initialStateRDD = javaPairRDD.rdd - this - } - - - def numPartitions(numPartitions: Int): this.type = { - this.partitioner(new HashPartitioner(numPartitions)) - this - } - - def partitioner(partitioner: Partitioner): this.type = { - this.partitioner = partitioner - this - } - - def timeout(interval: Duration): this.type = { - this.timeoutInterval = interval - this - } - - // ================= Private Methods ================= - - private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function - - private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) - - private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) - - private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) -} - - -private[streaming] class StateImpl[S] extends State[S] { - - private var state: S = null.asInstanceOf[S] - private var defined: Boolean = true - private var timingOut: Boolean = false - private var updated: Boolean = false - private var removed: Boolean = false - - // ========= Public API ========= - def isDefined(): Boolean = { - defined - } - - def get(): S = { - null.asInstanceOf[S] - } - - def update(newState: S): Unit = { - require(!removed, "Cannot update the state after it has been removed") - require(!timingOut, "Cannot update the state that is timing out") - updated = true - state = newState - } - - def isTimingOut(): Boolean = { - timingOut - } - - def remove(): Unit = { - require(!timingOut, "Cannot remove the state that is timing out") - removed = true - } - - // ========= Internal API ========= - - def isRemoved(): Boolean = { - removed - } - - def isUpdated(): Boolean = { - updated - } - - def wrap(optionalState: Option[S]): Unit = { - optionalState match { - case Some(newState) => - this.state = newState - defined = true - - case None => - this.state = null.asInstanceOf[S] - defined = false - } - timingOut = false - removed = false - updated = false - } - - def wrapTiminoutState(newState: S): Unit = { - this.state = newState - defined = true - timingOut = true - removed = false - updated = false - } - - -} - - - -// ----------------------------------------------- -// --------------- StateMap stuff -------------- -// ----------------------------------------------- - -/** Internal interface for defining the map that keeps track of sessions. */ -private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { - - /** Get the state for a key if it exists */ - def get(key: K): Option[S] - - /** Get all the keys and states whose updated time is older than the give threshold time */ - def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] - - /** Get all the keys and states in this map. */ - def getAll(): Iterator[(K, S, Long)] - - /** Add or update state */ - def put(key: K, state: S, updatedTime: Long): Unit - - /** Remove a key */ - def remove(key: K): Unit - - /** - * Shallow copy `this` map to create a new state map. - * Updates to the new map should not mutate `this` map. 
- */ - def copy(): StateMap[K, S] - - def toDebugString(): String = toString() -} - -private[streaming] object StateMap { - def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] - - def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { - val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", - DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) - } -} - -/** Specific implementation of SessionStore interface representing an empty map */ -private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { - override def put(key: K, session: S, updateTime: Long): Unit = ??? - override def get(key: K): Option[S] = None - override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty - override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] - override def remove(key: K): Unit = { } - override def getAll(): Iterator[(K, S, Long)] = Iterator.empty - override def toDebugString(): String = "" -} - -/** Implementation of StateMap based on Spark's OpenHashMap */ -private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( - @transient @volatile private var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, - deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD - ) extends StateMap[K, S] { self => - - def this(initialCapacity: Int, deltaChainThreshold: Int) = this( - new EmptyStateMap[K, S], - initialCapacity = initialCapacity, - deltaChainThreshold = deltaChainThreshold) - - def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) - - def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) - - /** Get the session data if it exists */ - override def get(key: K): Option[S] = { - val stateInfo = deltaMap(key) - if (stateInfo != null && !stateInfo.deleted) { - Some(stateInfo.data) - } else { - parentStateMap.get(key) - } - } - - /** Get all the keys and states whose updated time is older than the give threshold time */ - override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { - val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => - !deltaMap.contains(key) - } - - val updatedStates = deltaMap.iterator.flatMap { case (key, stateInfo) => - if (! stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime) { - Some((key, stateInfo.data, stateInfo.updateTime)) - } else None - } - oldStates ++ updatedStates - } - - /** Get all the keys and states in this map. */ - override def getAll(): Iterator[(K, S, Long)] = { - - val oldStates = parentStateMap.getAll().filter { case (key, _, _) => - !deltaMap.contains(key) - } - - val updatedStates = deltaMap.iterator.filter { ! 
_._2.deleted }.map { case (key, stateInfo) => - (key, stateInfo.data, stateInfo.updateTime) - } - oldStates ++ updatedStates - } - - /** Add or update state */ - override def put(key: K, state: S, updateTime: Long): Unit = { - val stateInfo = deltaMap(key) - if (stateInfo != null) { - stateInfo.update(state, updateTime) - } else { - deltaMap.update(key, new StateInfo(state, updateTime)) - } - } - - /** Remove a state */ - override def remove(key: K): Unit = { - val stateInfo = deltaMap(key) - if (stateInfo != null) { - stateInfo.markDeleted() - } else { - val newInfo = new StateInfo[S](deleted = true) - deltaMap.update(key, newInfo) - } - } - - /** - * Shallow copy the map to create a new session store. Updates to the new map - * should not mutate `this` map. - */ - override def copy(): StateMap[K, S] = { - new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) - } - - def shouldCompact: Boolean = { - deltaChainLength >= deltaChainThreshold - } - - def deltaChainLength: Int = parentStateMap match { - case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 - case _ => 0 - } - - def approxSize: Int = deltaMap.size + { - parentStateMap match { - case s: OpenHashMapBasedStateMap[_, _] => s.approxSize - case _ => 0 - } - } - - override def toDebugString(): String = { - val tabs = if (deltaChainLength > 0) { - (" " * (deltaChainLength - 1)) +"+--- " - } else "" - parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") - } - - /* - class CompactParentOnCompletionIterator(iterator: Iterator[(K, S, Long)]) - extends CompletionIterator[(K, S, Long), Iterator[(K, S, Long)]](iterator) { - - val newParentStateMap = - new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) - - override def next(): (K, S, Long) = { - val next = super.next() - newParentStateMap.put(next._1, next._2, next._3) - next - } - - override def completion(): Unit = { - self.parentStateMap = newParentStateMap - } - } - */ - - private def writeObject(outputStream: ObjectOutputStream): Unit = { - - outputStream.defaultWriteObject() - - // Write the deltaMap - outputStream.writeInt(deltaMap.size) - val deltaMapIterator = deltaMap.iterator - var deltaMapCount = 0 - while (deltaMapIterator.hasNext) { - deltaMapCount += 1 - val (key, stateInfo) = deltaMapIterator.next() - outputStream.writeObject(key) - outputStream.writeObject(stateInfo) - } - assert(deltaMapCount == deltaMap.size) - - // Write the parentStateMap while consolidating - val doCompaction = shouldCompact - val newParentSessionStore = if (doCompaction) { - new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) - } else { null } - - val iterOfActiveSessions = parentStateMap.getAll() - - var parentSessionCount = 0 - - outputStream.writeInt(approxSize) - - while(iterOfActiveSessions.hasNext) { - parentSessionCount += 1 - - val (key, state, updateTime) = iterOfActiveSessions.next() - outputStream.writeObject(key) - outputStream.writeObject(state) - outputStream.writeLong(updateTime) - - if (doCompaction) { - newParentSessionStore.deltaMap.update( - key, StateInfo(state, updateTime, deleted = false)) - } - } - val limiterObj = new Limiter(parentSessionCount) - outputStream.writeObject(limiterObj) - if (doCompaction) { - parentStateMap = newParentSessionStore - } - } - - private def readObject(inputStream: ObjectInputStream): Unit = { - inputStream.defaultReadObject() - - val deltaMapSize = inputStream.readInt() - deltaMap = new OpenHashMap[K, 
StateInfo[S]]() - var deltaMapCount = 0 - while (deltaMapCount < deltaMapSize) { - val key = inputStream.readObject().asInstanceOf[K] - val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] - deltaMap.update(key, sessionInfo) - deltaMapCount += 1 - } - - val parentSessionStoreSizeHint = inputStream.readInt() - val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) - - var parentSessionLoopDone = false - while(!parentSessionLoopDone) { - val obj = inputStream.readObject() - if (obj.isInstanceOf[Limiter]) { - parentSessionLoopDone = true - val expectedCount = obj.asInstanceOf[Limiter].num - assert(expectedCount == newParentSessionStore.deltaMap.size) - } else { - val key = obj.asInstanceOf[K] - val state = inputStream.readObject().asInstanceOf[S] - val updateTime = inputStream.readLong() - newParentSessionStore.deltaMap.update( - key, StateInfo(state, updateTime, deleted = false)) - } - } - parentStateMap = newParentSessionStore - } - -/* - private def writeObject(outputStream: ObjectOutputStream): Unit = { - if (deltaChainLength > deltaChainThreshold) { - val newParentSessionStore = - new OpenHashMapBasedSessionStore[K, S](initialCapacity = sizeHint, deltaChainThreshold) - val iterOfActiveSessions = parentStateMap.iterator(updatedSessionsOnly = false).filter { - _.isActive - } - - while (iterOfActiveSessions.hasNext) { - val session = iterOfActiveSessions.next() - newParentSessionStore.deltaMap.update( - session.getKey(), SessionInfo(session.getData(), deleted = false)) - } - parentStateMap = newParentSessionStore - } - outputStream.defaultWriteObject() - } - - private def readObject(inputStream: ObjectInputStream): Unit = { - inputStream.defaultReadObject() - } -*/ -} - -class Limiter(val num: Int) extends Serializable - - -private[streaming] object OpenHashMapBasedStateMap { - - case class StateInfo[S]( - var data: S = null.asInstanceOf[S], - var updateTime: Long = -1, - var deleted: Boolean = false) { - - def markDeleted(): Unit = { - deleted = true - } - - def update(newData: S, newUpdateTime: Long): Unit = { - data = newData - updateTime = newUpdateTime - deleted = false - } - } - - val DELTA_CHAIN_LENGTH_THRESHOLD = 20 -} - - - -// ----------------------------------------------- -// --------------- StateRDD stuff -------------- -// ----------------------------------------------- - -private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( - stateMap: StateMap[K, S], emittedRecords: Seq[T]) - - -private[streaming] class TrackStateRDDPartition( - idx: Int, - @transient private var prevStateRDD: RDD[_], - @transient private var partitionedDataRDD: RDD[_]) extends Partition { - - private[dstream] var previousSessionRDDPartition: Partition = null - private[dstream] var partitionedDataRDDPartition: Partition = null - - override def index: Int = idx - override def hashCode(): Int = idx - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { - // Update the reference to parent split at the time of task serialization - previousSessionRDDPartition = prevStateRDD.partitions(index) - partitionedDataRDDPartition = partitionedDataRDD.partitions(index) - oos.defaultWriteObject() - } -} - -private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - _sc: SparkContext, - private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], - private var partitionedDataRDD: RDD[(K, V)], - 
trackingFunction: (K, Option[V], State[S]) => Option[T], - currentTime: Long, timeoutThresholdTime: Option[Long] - ) extends RDD[TrackStateRDDRecord[K, S, T]]( - _sc, - List( - new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), - new OneToOneDependency(partitionedDataRDD)) - ) { - - @volatile private var doFullScan = false - - require(partitionedDataRDD.partitioner.nonEmpty) - require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) - - override val partitioner = prevStateRDD.partitioner - - override def checkpoint(): Unit = { - super.checkpoint() - doFullScan = true - } - - override def compute( - partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { - - val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] - val prevStateRDDIterator = prevStateRDD.iterator( - stateRDDPartition.previousSessionRDDPartition, context) - val dataIterator = partitionedDataRDD.iterator( - stateRDDPartition.partitionedDataRDDPartition, context) - if (!prevStateRDDIterator.hasNext) { - throw new SparkException(s"Could not find state map in previous state RDD") - } - - val newStateMap = prevStateRDDIterator.next().stateMap.copy() - val emittedRecords = new ArrayBuffer[T] - - val stateWrapper = new StateImpl[S]() - - dataIterator.foreach { case (key, value) => - stateWrapper.wrap(newStateMap.get(key)) - val emittedRecord = trackingFunction(key, Some(value), stateWrapper) - if (stateWrapper.isRemoved()) { - newStateMap.remove(key) - } else if (stateWrapper.isUpdated()) { - newStateMap.put(key, stateWrapper.get(), currentTime) - } - emittedRecords ++= emittedRecord - } - - if (doFullScan) { - if (timeoutThresholdTime.isDefined) { - newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - stateWrapper.wrapTiminoutState(state) - val emittedRecord = trackingFunction(key, None, stateWrapper) - emittedRecords ++= emittedRecord - } - } - } - - Iterator(TrackStateRDDRecord(newStateMap, emittedRecords)) - } - - override protected def getPartitions: Array[Partition] = { - Array.tabulate(prevStateRDD.partitions.length) { i => - new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} - } - - override def clearDependencies() { - super.clearDependencies() - prevStateRDD = null - partitionedDataRDD = null - } -} - -private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag]( - pairRDD: RDD[(K, S)], - partitioner: Partitioner, - updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = { - - val createRecord = (iterator: Iterator[(K, S)]) => { - val stateMap = StateMap.create[K, S](SparkEnv.get.conf) - iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) - } - pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]]( - createRecord, true) - } -} - - -// ----------------------------------------------- -// ---------------- SessionDStream --------------- -// ----------------------------------------------- - - -private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) - extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { - - persist(StorageLevel.MEMORY_ONLY) - - private val partitioner = spec.getPartitioner().getOrElse( - new HashPartitioner(ssc.sc.defaultParallelism)) - - private val trackingFunction = spec.getFunction() - - override def slideDuration: Duration = 
parent.slideDuration - - override def dependencies: List[DStream[_]] = List(parent) - - override val mustCheckpoint = true - - /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, S, T]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, - validTime.milliseconds - ) - } - val newDataRDD = parent.getOrCompute(validTime).get - val partitionedDataRDD = newDataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - - Some(new TrackStateRDD( - ssc.sparkContext, prevStateRDD, partitionedDataRDD, - trackingFunction, validTime.milliseconds, timeoutThresholdTime)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala new file mode 100644 index 000000000000..e4a6998480a9 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala @@ -0,0 +1,217 @@ +package org.apache.spark.streaming.dstream + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.util.StateMap +import org.apache.spark.util.Utils + + +// ================================================== +// ================================================== +// ================= PUBLIC CLASSES ================= +// ================================================== +// ================================================== + + + + + + + + + +// =============================================== +// =============================================== +// ============== PRIVATE CLASSES ================ +// =============================================== +// =============================================== + + + + + + + + +// ----------------------------------------------- +// --------------- StateMap stuff -------------- +// ----------------------------------------------- + + + + + + + +// ----------------------------------------------- +// --------------- StateRDD stuff -------------- +// ----------------------------------------------- + +private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( + stateMap: StateMap[K, S], emittedRecords: Seq[T]) + + +private[streaming] class TrackStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[dstream] var previousSessionRDDPartition: Partition = null + private[dstream] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: 
ClassTag, T: ClassTag]( + _sc: SparkContext, + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], + private var partitionedDataRDD: RDD[(K, V)], + trackingFunction: (K, Option[V], State[S]) => Option[T], + currentTime: Long, timeoutThresholdTime: Option[Long] + ) extends RDD[TrackStateRDDRecord[K, S, T]]( + _sc, + List( + new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(partitionedDataRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { + + val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + if (!prevStateRDDIterator.hasNext) { + throw new SparkException(s"Could not find state map in previous state RDD") + } + + val newStateMap = prevStateRDDIterator.next().stateMap.copy() + val emittedRecords = new ArrayBuffer[T] + + val stateWrapper = new StateImpl[S]() + + dataIterator.foreach { case (key, value) => + stateWrapper.wrap(newStateMap.get(key)) + val emittedRecord = trackingFunction(key, Some(value), stateWrapper) + if (stateWrapper.isRemoved()) { + newStateMap.remove(key) + } else if (stateWrapper.isUpdated()) { + newStateMap.put(key, stateWrapper.get(), currentTime) + } + emittedRecords ++= emittedRecord + } + + if (doFullScan) { + if (timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + stateWrapper.wrapTiminoutState(state) + val emittedRecord = trackingFunction(key, None, stateWrapper) + emittedRecords ++= emittedRecord + } + } + } + + Iterator(TrackStateRDDRecord(newStateMap, emittedRecords)) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies() { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } +} + +private[streaming] object TrackStateRDD { + def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = { + + val createRecord = (iterator: Iterator[(K, S)]) => { + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + } + pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]]( + createRecord, true) + } +} + + +// ----------------------------------------------- +// ---------------- SessionDStream --------------- +// ----------------------------------------------- + + +private[streaming] class TrackedStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) + extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = 
spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, S, T]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime.milliseconds + ) + } + val newDataRDD = parent.getOrCompute(validTime).get + val partitionedDataRDD = newDataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + + Some(new TrackStateRDD( + ssc.sparkContext, prevStateRDD, partitionedDataRDD, + trackingFunction, validTime.milliseconds, timeoutThresholdTime)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 000000000000..b46b3003200d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,276 @@ +package org.apache.spark.streaming.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the give threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. + */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + } +} + +/** Specific implementation of SessionStore interface representing an empty map */ +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = ??? 
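+  // `???` throws scala.NotImplementedError. An EmptyStateMap is used only as the read-only
+  // root of the delta chain (it is the default parent of OpenHashMapBasedStateMap below),
+  // so put() is never expected to be reached.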
+ override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] + override def remove(key: K): Unit = { } + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def toDebugString(): String = "" +} + + + +/** Implementation of StateMap based on Spark's OpenHashMap */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile private var parentStateMap: StateMap[K, S], + initialCapacity: Int = 64, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) = this( + initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + + @transient @volatile private var deltaMap = + new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null && !stateInfo.deleted) { + Some(stateInfo.data) + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.flatMap { case (key, stateInfo) => + if (! stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime) { + Some((key, stateInfo.data, stateInfo.updateTime)) + } else None + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. 
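+ * The copy is a new delta map layered on top of `this`: lookups that miss the copy's own
+ * deltaMap fall through to `this` map, and once the delta chain grows past
+ * deltaChainThreshold it is compacted into a single map on the next serialization
+ * (see shouldCompact and writeObject below).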
+ */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) +"+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + + outputStream.defaultWriteObject() + + // Write the deltaMap + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the parentStateMap while consolidating + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + outputStream.writeInt(approxSize) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) + + if (doCompaction) { + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + val limiterObj = new LimitMarker(parentSessionCount) + outputStream.writeObject(limiterObj) + if (doCompaction) { + parentStateMap = newParentSessionStore + } + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + inputStream.defaultReadObject() + + val deltaMapSize = inputStream.readInt() + deltaMap = new OpenHashMap[K, StateInfo[S]]() + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + + val parentSessionStoreSizeHint = inputStream.readInt() + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( + initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + if (obj.isInstanceOf[LimitMarker]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[LimitMarker].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + parentStateMap = newParentSessionStore + } +} + +private[streaming] object OpenHashMapBasedStateMap { + + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var deleted: Boolean = false) { + + def markDeleted(): Unit = { + 
deleted = true + } + + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 +} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 75ec238dd926..1e5a997314ea 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.{State, TrackStateSpec, DStream, WindowedDStream} +import org.apache.spark.streaming.dstream.{TrackStateSpec, DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.{HashPartitioner, SparkConf, SparkException} From b7c653d164536fc1954597451915221b39aae9f3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Oct 2015 15:23:51 -0700 Subject: [PATCH 12/26] Added licenses, and fixed stuff --- .../streaming/StatefulNetworkWordCount.scala | 24 +++--- .../org/apache/spark/streaming/State.scala | 30 +++++++- .../spark/streaming/TrackStateSpec.scala | 17 ++++ .../dstream/PairDStreamFunctions.scala | 2 +- ...DStream.scala => TrackeStateDStream.scala} | 77 +++++++------------ .../spark/streaming/util/StateMap.scala | 22 +++++- .../streaming/BasicOperationsSuite.scala | 6 +- .../spark/streaming/StateMapSuite.scala | 20 ++++- 8 files changed, 121 insertions(+), 77 deletions(-) rename streaming/src/main/scala/org/apache/spark/streaming/dstream/{TrackedStateDStream.scala => TrackeStateDStream.scala} (80%) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f..50f637f9762f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,18 +44,6 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) @@ -72,8 +60,16 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + + val trackStateFunc = (word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOrElse(0) + val output = (word, sum) + state.update(sum) + Some(output) + } + + val stateDstream = wordDstream.trackStateByKey( + 
TrackStateSpec(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 0f660bd94060..d6ea2252c054 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.streaming /** @@ -32,9 +49,14 @@ sealed abstract class State[S] { /** Is the state going to be timed out by the system after this batch interval */ def isTimingOut(): Boolean + @inline final def getOption(): Option[S] = Option(get()) + /** Get the state if it exists, otherwise return the default value */ - @inline final def getOrElse[S1 >: S](default: => S1): S1 = - if (exists) default else this.get + @inline final def getOrElse[S1 >: S](default: => S1): S1 = { + if (exists) this.get else default + } + + @inline final override def toString() = getOption.map { _.toString }.getOrElse("") } /** Internal implementation of the [[State]] interface */ @@ -52,14 +74,14 @@ private[streaming] class StateImpl[S] extends State[S] { } def get(): S = { - null.asInstanceOf[S] + state } def update(newState: S): Unit = { require(!removed, "Cannot update the state after it has been removed") require(!timingOut, "Cannot update the state that is timing out") - updated = true state = newState + updated = true } def isTimingOut(): Boolean = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala index 86f6b669367c..f0edcf2b9bfe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.streaming import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 77245923426b..bac511b89921 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -350,7 +350,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) } def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = { - new TrackedStateDStream[K, V, S, T]( + new TrackeStateDStream[K, V, S, T]( self, spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] ).mapPartitions { partitionIter => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala similarity index 80% rename from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala rename to streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala index e4a6998480a9..205d6c70ef23 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackedStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectOutputStream} @@ -12,48 +29,6 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.util.StateMap import org.apache.spark.util.Utils - -// ================================================== -// ================================================== -// ================= PUBLIC CLASSES ================= -// ================================================== -// ================================================== - - - - - - - - - -// =============================================== -// =============================================== -// ============== PRIVATE CLASSES ================ -// =============================================== -// =============================================== - - - - - - - - -// ----------------------------------------------- -// --------------- StateMap stuff -------------- -// ----------------------------------------------- - - - - - - - -// ----------------------------------------------- -// --------------- StateRDD stuff -------------- -// ----------------------------------------------- - private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( stateMap: StateMap[K, S], emittedRecords: Seq[T]) @@ -118,15 +93,15 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: val newStateMap = prevStateRDDIterator.next().stateMap.copy() val emittedRecords = new ArrayBuffer[T] - val stateWrapper = new StateImpl[S]() + val wrappedState = new StateImpl[S]() dataIterator.foreach { case (key, value) => - stateWrapper.wrap(newStateMap.get(key)) - val emittedRecord = trackingFunction(key, Some(value), stateWrapper) - if (stateWrapper.isRemoved()) { + wrappedState.wrap(newStateMap.get(key)) + val emittedRecord = trackingFunction(key, Some(value), wrappedState) + if (wrappedState.isRemoved) { newStateMap.remove(key) - } else if (stateWrapper.isUpdated()) { - newStateMap.put(key, stateWrapper.get(), currentTime) + } else if (wrappedState.isUpdated) { + newStateMap.put(key, wrappedState.get(), currentTime) } emittedRecords ++= emittedRecord } @@ -134,8 +109,8 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: if (doFullScan) { if (timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - stateWrapper.wrapTiminoutState(state) - val emittedRecord = trackingFunction(key, None, stateWrapper) + wrappedState.wrapTiminoutState(state) + val emittedRecord = trackingFunction(key, None, wrappedState) emittedRecords ++= emittedRecord } } @@ -178,7 +153,7 @@ private[streaming] object TrackStateRDD { // ----------------------------------------------- -private[streaming] class TrackedStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( +private[streaming] class TrackeStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index b46b3003200d..d19ed1b31eab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.streaming.util import java.io.{ObjectInputStream, ObjectOutputStream} @@ -14,7 +31,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser /** Get the state for a key if it exists */ def get(key: K): Option[S] - /** Get all the keys and states whose updated time is older than the give threshold time */ + /** Get all the keys and states whose updated time is older than the given threshold time */ def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] /** Get all the keys and states in this map. */ @@ -187,7 +204,8 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Write the parentStateMap while consolidating val doCompaction = shouldCompact val newParentSessionStore = if (doCompaction) { - new OpenHashMapBasedStateMap[K, S](initialCapacity = approxSize, deltaChainThreshold) + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) } else { null } val iterOfActiveSessions = parentStateMap.getAll() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 1e5a997314ea..310cf2d46700 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -651,17 +651,17 @@ class BasicOperationsSuite extends TestSuiteBase { Seq() ) - val sessionOperation = (s: DStream[String]) => { + val trackStateOp = (s: DStream[String]) => { val updateFunc = (key: String, value: Option[Int], state: State[Int]) => { val sum = value.getOrElse(0) + state.getOrElse(0) val output = (key, sum) state.update(sum) Some(output) } - s.map(x => (x, 1)).trackStateByKey(TrackStateSpec.create(updateFunc)) + s.map(x => (x, 1)).trackStateByKey(TrackStateSpec(updateFunc)) } - testOperation(inputData, sessionOperation, outputData, true) + testOperation(inputData, trackStateOp, outputData, true) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 604d031c93f7..bd1936e9ad9d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -1,10 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.streaming -import scala.reflect._ import scala.util.Random import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.dstream.{StateMap, OpenHashMapBasedStateMap} +import org.apache.spark.streaming.util.{OpenHashMapBasedStateMap, StateMap} import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { From be8cffc6ccddb14d993564669f380c300e83dc6b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 2 Nov 2015 04:00:42 -0800 Subject: [PATCH 13/26] Refactored code, fixed bugs, added unit tests --- .../org/apache/spark/streaming/State.scala | 1 + .../dstream/EmittedRecordsDStream.scala | 100 ++++++++ .../dstream/PairDStreamFunctions.scala | 14 +- .../TrackStateRDD.scala} | 145 +++++------ .../spark/streaming/util/StateMap.scala | 14 +- .../streaming/BasicOperationsSuite.scala | 37 +-- .../spark/streaming/StateMapSuite.scala | 114 ++++++++- .../streaming/TrackStateByKeySuite.scala | 229 ++++++++++++++++++ .../streaming/rdd/TrackStateRDDSuite.scala | 172 +++++++++++++ 9 files changed, 690 insertions(+), 136 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala rename streaming/src/main/scala/org/apache/spark/streaming/{dstream/TrackeStateDStream.scala => rdd/TrackStateRDD.scala} (50%) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index d6ea2252c054..49c68a352745 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -90,6 +90,7 @@ private[streaming] class StateImpl[S] extends State[S] { def remove(): Unit = { require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") removed = true } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala new file mode 100644 index 000000000000..192a99468dfa --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} + + + +abstract class EmittedRecordsDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + ssc: StreamingContext) extends DStream[T](ssc) { + + def stateSnapshots(): DStream[(K, S)] +} + + +private[streaming] class EmittedRecordsDStreamImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackStateDStream: TrackStateDStream[K, V, S, T]) + extends EmittedRecordsDStream[K, V, S, T](trackStateDStream.context) { + + override def slideDuration: Duration = trackStateDStream.slideDuration + + override def dependencies: List[DStream[_]] = List(trackStateDStream) + + override def compute(validTime: Time): Option[RDD[T]] = { + trackStateDStream.getOrCompute(validTime).map { _.flatMap[T] { _.emittedRecords } } + } + + def stateSnapshots(): DStream[(K, S)] = { + trackStateDStream.flatMap[(K, S)] { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } +} + +/** + * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * based on updates to the state. + * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the trackStateByKey operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam T Type of the eiitted records + */ +private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) + extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, V, S, T]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime.milliseconds + ) + } + val newDataRDD = parent.getOrCompute(validTime).get + val partitionedDataRDD = newDataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + + Some(new TrackStateRDD(prevStateRDD, partitionedDataRDD, + trackingFunction, validTime.milliseconds, timeoutThresholdTime)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index bac511b89921..b2d307261661 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -349,13 +349,13 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } - def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = { - new TrackeStateDStream[K, V, S, T]( - self, - spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] - ).mapPartitions { partitionIter => - partitionIter.flatMap { _.emittedRecords } - } + def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): EmittedRecordsDStream[K, V, S, T] = { + new EmittedRecordsDStreamImpl[K, V, S, T]( + new TrackStateDStream[K, V, S, T]( + self, + spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] + ) + ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala similarity index 50% rename from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala rename to streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 205d6c70ef23..d5c779c534d3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackeStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -1,36 +1,42 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.dstream +package org.apache.spark.streaming.rdd -import java.io.{IOException, ObjectOutputStream} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import org.apache.spark._ -import org.apache.spark.rdd.{EmptyRDD, RDD} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming._ -import org.apache.spark.streaming.util.StateMap +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.streaming.{StateImpl, State} +import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils +import org.apache.spark._ + + + + +private[streaming] case class TrackStateRDDRecord[K, S, T]( + var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) { +/* + private def writeObject(outputStream: ObjectOutputStream): Unit = { + outputStream.writeObject(stateMap) + outputStream.writeInt(emittedRecords.size) + val iterator = emittedRecords.iterator + while(iterator.hasNext) { + outputStream.writeObject(iterator.next) + } + } -private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag]( - stateMap: StateMap[K, S], emittedRecords: Seq[T]) + private def readObject(inputStream: ObjectInputStream): Unit = { + stateMap = inputStream.readObject().asInstanceOf[StateMap[K, S]] + val numEmittedRecords = inputStream.readInt() + val array = new Array[T](numEmittedRecords) + var i = 0 + while(i < numEmittedRecords) { + array(i) = inputStream.readObject().asInstanceOf[T] + } + emittedRecords = array.toSeq + }*/ +} private[streaming] class TrackStateRDDPartition( @@ -38,8 +44,8 @@ private[streaming] class TrackStateRDDPartition( @transient private var prevStateRDD: RDD[_], @transient private var partitionedDataRDD: RDD[_]) extends Partition { - private[dstream] var previousSessionRDDPartition: Partition = null - private[dstream] var partitionedDataRDDPartition: Partition = null + private[rdd] var previousSessionRDDPartition: Partition = null + private[rdd] var partitionedDataRDDPartition: Partition = null override def index: Int = idx override def hashCode(): Int = idx @@ -53,14 +59,20 @@ private[streaming] class TrackStateRDDPartition( } } + + + +/** + * RDD storing the keyed-state of trackStateByKey and corresponding emitted records. 
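+ * Concretely, each element is a TrackStateRDDRecord[K, S, T] (the case class defined above),
+ * pairing a partition's StateMap[K, S] of keyed state with the Seq[T] of records emitted by
+ * the tracking function during the current batch.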
+ * Each partition of this RDD has a single record that contains a StateMap storing + */ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - _sc: SparkContext, private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], private var partitionedDataRDD: RDD[(K, V)], trackingFunction: (K, Option[V], State[S]) => Option[T], currentTime: Long, timeoutThresholdTime: Option[Long] ) extends RDD[TrackStateRDDRecord[K, S, T]]( - _sc, + partitionedDataRDD.sparkContext, List( new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), new OneToOneDependency(partitionedDataRDD)) @@ -68,7 +80,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: @volatile private var doFullScan = false - require(partitionedDataRDD.partitioner.nonEmpty) + require(prevStateRDD.partitioner.nonEmpty) require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) override val partitioner = prevStateRDD.partitioner @@ -86,11 +98,13 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: stateRDDPartition.previousSessionRDDPartition, context) val dataIterator = partitionedDataRDD.iterator( stateRDDPartition.partitionedDataRDDPartition, context) - if (!prevStateRDDIterator.hasNext) { - throw new SparkException(s"Could not find state map in previous state RDD") + + val newStateMap = if (prevStateRDDIterator.hasNext) { + prevStateRDDIterator.next().stateMap.copy() + } else { + new EmptyStateMap[K, S]() } - val newStateMap = prevStateRDDIterator.next().stateMap.copy() val emittedRecords = new ArrayBuffer[T] val wrappedState = new StateImpl[S]() @@ -132,61 +146,32 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: } private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag]( + + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = { + updateTime: Long): TrackStateRDD[K, V, S, T] = { - val createRecord = (iterator: Iterator[(K, S)]) => { + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) - } - pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]]( - createRecord, true) - } -} - - -// ----------------------------------------------- -// ---------------- SessionDStream --------------- -// ----------------------------------------------- - - -private[streaming] class TrackeStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) - extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { + }, preservesPartitioning = true) - persist(StorageLevel.MEMORY_ONLY) + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) - private val partitioner = spec.getPartitioner().getOrElse( - new HashPartitioner(ssc.sc.defaultParallelism)) + val noOpFunc = (key: K, value: Option[V], state: State[S]) => None - private val trackingFunction = spec.getFunction() - - override def slideDuration: Duration = parent.slideDuration - - override def dependencies: List[DStream[_]] = List(parent) - - override val mustCheckpoint = true - - /** Method that generates a RDD for the given time */ - 
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, S, T]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, - validTime.milliseconds - ) - } - val newDataRDD = parent.getOrCompute(validTime).get - val partitionedDataRDD = newDataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } + new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } +} - Some(new TrackStateRDD( - ssc.sparkContext, prevStateRDD, partitionedDataRDD, - trackingFunction, validTime.milliseconds, timeoutThresholdTime)) +private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) { + override protected def getPartitions: Array[Partition] = parent.partitions + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + parent.compute(partition, context).flatMap { _.emittedRecords } } } + +private[streaming] class StateSnapshotRDD[K: ClassTag, V: ClassTag] \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index d19ed1b31eab..51e1c1831614 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -78,7 +78,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's OpenHashMap */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( - @transient @volatile private var parentStateMap: StateMap[K, S], + @transient @volatile var parentStateMap: StateMap[K, S], initialCapacity: Int = 64, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -99,8 +99,12 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( /** Get the session data if it exists */ override def get(key: K): Option[S] = { val stateInfo = deltaMap(key) - if (stateInfo != null && !stateInfo.deleted) { - Some(stateInfo.data) + if (stateInfo != null) { + if (!stateInfo.deleted) { + Some(stateInfo.data) + } else { + None + } } else { parentStateMap.get(key) } @@ -185,6 +189,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") } + override def toString(): String = { + s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" + } + private def writeObject(outputStream: ObjectOutputStream): Unit = { outputStream.defaultWriteObject() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 3335bb1906eb..bd98351c8107 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.{TrackStateSpec, 
DStream, WindowedDStream} +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.{HashPartitioner, SparkConf, SparkException} @@ -709,41 +709,6 @@ class BasicOperationsSuite extends TestSuiteBase { } - test("trackStateByKey with emitted states") { - val inputData = - Seq( - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3)), - Seq(("a", 5)), - Seq() - ) - - val trackStateOp = (s: DStream[String]) => { - val updateFunc = (key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOrElse(0) - val output = (key, sum) - state.update(sum) - Some(output) - } - s.map(x => (x, 1)).trackStateByKey(TrackStateSpec(updateFunc)) - } - - testOperation(inputData, trackStateOp, outputData, true) - } - - /** Test cleanup of RDDs in DStream metadata */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index bd1936e9ad9d..afdd4514cd32 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming +import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite @@ -25,7 +26,7 @@ import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { - test("OpenHashMapBasedStateMap - basic operations") { + test("OpenHashMapBasedStateMap - put, get, getall, remove") { val map = new OpenHashMapBasedStateMap[Int, Int]() map.put(1, 100, 10) @@ -39,36 +40,128 @@ class StateMapSuite extends SparkFunSuite { assert(map.getAll().toSet === Set((2, 200, 20))) } - test("OpenHashMapBasedStateMap - basic operations after copy") { + test("OpenHashMapBasedStateMap - put, get, getall, remove after copy") { val parentMap = new OpenHashMapBasedStateMap[Int, Int]() parentMap.put(1, 100, 1) parentMap.put(2, 200, 2) parentMap.remove(1) + // Create child map and make changes val map = parentMap.copy() assert(map.getAll().toSet === Set((2, 200, 2))) + assert(map.get(1) === None) + assert(map.get(2) === Some(200)) // Add new items map.put(3, 300, 3) + assert(map.get(3) === Some(300)) map.put(4, 400, 4) + assert(map.get(4) === Some(400)) assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) assert(parentMap.getAll().toSet === Set((2, 200, 2))) // Remove items - map.remove(4) // remove item added to this map - map.remove(2) // remove item remove in parent map + map.remove(4) + assert(map.get(4) === None) // item added in this map, then removed in this map + map.remove(2) + assert(map.get(2) === None) // item removed in parent map, then added in this map assert(map.getAll().toSet === Set((3, 300, 3))) assert(parentMap.getAll().toSet === Set((2, 200, 2))) // Update items - map.put(1, 1000, 100) // update item removed in parent map - map.put(2, 2000, 200) // update item added in parent map and removed in this map - map.put(3, 3000, 300) // update item added in this map - map.put(4, 4000, 400) // update item removed in this map + map.put(1, 1000, 100) + assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map + map.put(2, 2000, 200) + 
assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map + map.put(3, 3000, 300) + assert(map.get(3) === Some(3000)) // item added + updated in this map + map.put(4, 4000, 400) + assert(map.get(4) === Some(4000)) // item removed + updated in this map assert(map.getAll().toSet === Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + map.remove(2) // remove item present in parent map, so that its not visible in child map + + // Create child map and see availability of items + val childMap = map.copy() + assert(childMap.getAll().toSet === map.getAll().toSet) + assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map + assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map + assert(childMap.get(3) === Some(3000)) // item added and updated in parent map + + childMap.put(2, 20000, 200) + assert(childMap.get(2) === Some(20000)) // item map + } + + test("OpenHashMapBasedStateMap - all operation combo testing with copies ") { + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numMapCopies = 4 // to test all combos of operations across 4 copies + val numOpsPerCopy = numTypeMapOps + val numTotalOps = numOpsPerCopy * numMapCopies + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + val refMap = new mutable.HashMap[Int, Int]() + + def assertMap(): Unit = { + assert(stateMap.getAll().map { x => (x._1, x._2) }.toSet === refMap.iterator.toSet) + for (keyId <- 0 until numKeys) { + assert(stateMap.get(keyId) === refMap.get(keyId)) + } + } + + /** + * Example: Operations combinations with 2 map copies + * + * ----------------------------------------------- + * | | Copy1 | Copy2 | + * | |-----------------|-----------------| + * | | Op1 Op2 | Op3 Op4 | + * | --------|-----------------|-----------------| + * | key 0 | put put | | put put | + * | key 1 | put put | | put rem | + * | key 2 | put put |c| rem put | + * | key 3 | put put |o| rem rem | + * | key 4 | put rem |p| put put | + * | key 5 | put rem |y| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |t| rem rem | + * | key 8 | rem put |h| put put | + * | key 9 | rem put |e| put rem | + * | key 10 | rem put | | rem put | + * | key 11 | rem put |m| rem rem | + * | key 12 | rem rem |a| put put | + * | key 13 | rem rem |p| put rem | + * | key 14 | rem rem | | rem put | + * | key 15 | rem rem | | rem rem | + * ----------------------------------------------- + */ + + + for(opId <- 0 until numTotalOps) { + for (keyId <- 0 until numKeys) { + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. 
nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, value * 2) + refMap.put(keyId, value) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + if (opId % numOpsPerCopy == 0) { + assertMap() + stateMap = stateMap.copy() + } + } + assertMap() } test("OpenHashMapBasedStateMap - serializing and deserializing") { @@ -87,7 +180,8 @@ class StateMapSuite extends SparkFunSuite { // Do not test compaction assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - val map3_ = Utils.deserialize[StateMap[Int, Int]](Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + val map3_ = Utils.deserialize[StateMap[Int, Int]]( + Utils.serialize(map3), Thread.currentThread().getContextClassLoader) assert(map3_.getAll().toSet === map3.getAll().toSet) assert(map3.getAll().forall { case (key, state, _) => map3_.get(key) === Some(state)}) } @@ -99,6 +193,7 @@ class StateMapSuite extends SparkFunSuite { var map = new OpenHashMapBasedStateMap[Int, Int]( deltaChainThreshold = deltaChainThreshold) + // Make large delta chain with length more than deltaChainThreshold for(i <- 1 to targetDeltaLength) { map.put(Random.nextInt(), Random.nextInt(), Random.nextLong()) map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] @@ -112,6 +207,5 @@ class StateMapSuite extends SparkFunSuite { assert(deser_map.shouldCompact === false) assert(deser_map.getAll().toSet === map.getAll().toSet) assert(map.getAll().forall { case (key, state, _) => deser_map.get(key) === Some(state)}) - } } \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala new file mode 100644 index 000000000000..53ab0236b5a2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -0,0 +1,229 @@ +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.reflect.ClassTag + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.util.{Utils, ManualClock} +import org.apache.spark.{SparkFunSuite, SparkConf, SparkContext} +import org.apache.spark.streaming.dstream.{EmittedRecordsDStream, DStream} + +/** + * Created by tdas on 10/29/15. 
+ */ +class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + private var ssc: StreamingContext = null + private var checkpointDir: File = null + private val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + checkpointDir = Utils.createTempDir("checkpoint") + + ssc = new StreamingContext(sc, batchDuration) + ssc.checkpoint(checkpointDir.toString) + } + + after { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + + override def beforeAll(): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + test("basic operation") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, TrackStateSpec(trackStateFunc), outputData, stateData) + } + + test("states as emitted records") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, TrackStateSpec(trackStateFunc), outputData, stateData) + } + + test("state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + println(s"$key: state exists, removed state, and returning key") + Some(key) + } else { + state.update(value.get) + println(s"$key: State does not exists, saving state, and not returning anything") + None + } + } + + testOperation(inputData, TrackStateSpec(trackStateFunc).numPartitions(1), outputData, stateData) + } + + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + 
trackStateSpec: TrackStateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ) { + + require(expectedOutputs.size == expectedStateSnapshots.size) + + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) + val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) + val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackeStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.start() + + val numBatches = expectedOutputs.size + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})") + expected.zip(collected).foreach { case (c, e) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected" + + "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + ) + } + } + + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala new file mode 100644 index 000000000000..c37c2751b2d3 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -0,0 +1,172 @@ +package org.apache.spark.streaming.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.State +import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} + +class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + + override def afterAll(): Unit = { + sc.stop() + } + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, 123) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("updating state on existing TrackStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, initTime).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, 
when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: TrackStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = { + + // Persist the test TrackStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only track which state keys are being touched in the next op. + testStateRDD.persist().count() + + // To track which keys are being touched + TrackStateRDDSuite.touchedStateKeys.clear() + + val trackingFunc = (key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + TrackStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size, + "More number of keys are being touched than that is expected") + assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates( + rdd3, Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( + rdd7, 
Seq(("k3", 2)), Set() // + ) + } + + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: TrackStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + trackStateFunc: (K, Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T]): TrackStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new TrackStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, trackStateFunc, currentTime, None) + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computing touched + newStateRDD.persist() + assertRDD(newStateRDD, expectedStates, expectedEmittedRecords) + newStateRDD + } + + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackStateRDD: TrackStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T]): Unit = { + val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet + assert(states === expectedStates, "states after track state operation were not as expected") + assert(emittedRecords === expectedEmittedRecords, + "emitted records after track state operation were not as expected") + } +} + +object TrackStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +} \ No newline at end of file From 6c02f4479bbc8822382a74f5dab2caf557ad5ae3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 2 Nov 2015 04:11:45 -0800 Subject: [PATCH 14/26] Fixing few things --- .../dstream/EmittedRecordsDStream.scala | 1 - .../spark/streaming/rdd/TrackStateRDD.scala | 24 +++++++++++++------ .../spark/streaming/util/StateMap.scala | 2 +- .../streaming/BasicOperationsSuite.scala | 5 ++-- .../spark/streaming/StateMapSuite.scala | 9 ++++--- .../streaming/TrackStateByKeySuite.scala | 22 +++++++++++++---- 6 files changed, 42 insertions(+), 21 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala index 192a99468dfa..55ce8689f101 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala @@ -26,7 +26,6 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} - abstract class EmittedRecordsDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( ssc: StreamingContext) extends DStream[T](ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index d5c779c534d3..b46fa34dd861 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.streaming.rdd import java.io.{IOException, ObjectInputStream, ObjectOutputStream} @@ -11,9 +28,6 @@ import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils import org.apache.spark._ - - - private[streaming] case class TrackStateRDDRecord[K, S, T]( var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) { /* @@ -60,8 +74,6 @@ private[streaming] class TrackStateRDDPartition( } - - /** * RDD storing the keyed-state of trackStateByKey and corresponding emitted records. * Each partition of this RDD has a single record that contains a StateMap storing @@ -173,5 +185,3 @@ private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag parent.compute(partition, context).flatMap { _.emittedRecords } } } - -private[streaming] class StateSnapshotRDD[K: ClassTag, V: ClassTag] \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 51e1c1831614..67483f00a2f8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -299,4 +299,4 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bd98351c8107..9d296c6d3ef8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,11 +22,13 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} -import org.apache.spark.{HashPartitioner, SparkConf, SparkException} +import org.apache.spark.HashPartitioner class BasicOperationsSuite extends TestSuiteBase { test("map") { @@ -708,7 +710,6 @@ class BasicOperationsSuite extends TestSuiteBase { } } - /** Test cleanup of RDDs in DStream metadata */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index afdd4514cd32..f927d7e49165 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -115,11 +115,11 @@ class StateMapSuite extends SparkFunSuite { /** * Example: Operations combinations with 2 map copies * - * ----------------------------------------------- + * _______________________________________________ * | | Copy1 | Copy2 | * | |-----------------|-----------------| * | | Op1 Op2 | Op3 Op4 | - * | --------|-----------------|-----------------| + * |---------|-----------------|-----------------| * | key 0 | put put | | put put | * | key 1 | put put | | put rem | * | key 2 | put put |c| rem put | @@ -136,10 +136,9 @@ class StateMapSuite extends SparkFunSuite { * | key 13 | rem rem |p| put rem | * | key 14 | rem rem | | rem put | * | key 15 | rem rem | | rem rem | - * ----------------------------------------------- + * |_________|_________________|_________________| */ - for(opId <- 0 until numTotalOps) { for (keyId <- 0 until numKeys) { // Find the operation type that needs to be done @@ -208,4 +207,4 @@ class StateMapSuite extends SparkFunSuite { assert(deser_map.getAll().toSet === map.getAll().toSet) assert(map.getAll().forall { case (key, state, _) => deser_map.get(key) === Some(state)}) } -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 53ab0236b5a2..07831cc297d6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.streaming import java.io.File @@ -11,9 +28,6 @@ import org.apache.spark.util.{Utils, ManualClock} import org.apache.spark.{SparkFunSuite, SparkConf, SparkContext} import org.apache.spark.streaming.dstream.{EmittedRecordsDStream, DStream} -/** - * Created by tdas on 10/29/15. 
- */ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null @@ -224,6 +238,4 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ) } } - - } From 23596b8dc368d3d5c3e82f3036f06fd9eefa8c2e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 3 Nov 2015 01:13:34 -0800 Subject: [PATCH 15/26] Fixed timeout bug, refactored tests, added tests --- .../spark/streaming/rdd/TrackStateRDD.scala | 1 + .../spark/streaming/util/StateMap.scala | 2 +- .../spark/streaming/StateMapSuite.scala | 262 ++++++++++++------ .../streaming/TrackStateByKeySuite.scala | 64 ++++- 4 files changed, 239 insertions(+), 90 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index b46fa34dd861..db793e47e751 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -138,6 +138,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: wrappedState.wrapTiminoutState(state) val emittedRecord = trackingFunction(key, None, wrappedState) emittedRecords ++= emittedRecord + newStateMap.remove(key) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 67483f00a2f8..cdd85452d30b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -68,9 +68,9 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa override def put(key: K, session: S, updateTime: Long): Unit = ??? 
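The timeout handling in the TrackStateRDD hunk above invokes the tracking function one final time with a timing-out state and, with this fix, also drops the key from the new state map. A rough sketch of that pass follows; it reuses the local names visible in that hunk (newStateMap, wrappedState, trackingFunction, emittedRecords), while the surrounding scan over idle keys is assumed rather than quoted from the patch.

    // Hypothetical sketch: keys idle past the timeout threshold get one last call
    // with isTimingOut == true, and their state is then removed from the new map.
    newStateMap.getByTime(timeoutThresholdTime).foreach { case (key, state, _) =>
      wrappedState.wrapTiminoutState(state)            // mark the wrapped state as timing out
      val emittedRecord = trackingFunction(key, None, wrappedState)
      emittedRecords ++= emittedRecord                 // Option -> zero or one emitted record
      newStateMap.remove(key)                          // drop the expired key's state
    }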
override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] override def remove(key: K): Unit = { } - override def getAll(): Iterator[(K, S, Long)] = Iterator.empty override def toDebugString(): String = "" } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index f927d7e49165..bb4f8481a980 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,22 +17,43 @@ package org.apache.spark.streaming -import scala.collection.mutable +import scala.collection.{immutable, mutable, Map} import scala.util.Random import org.apache.spark.SparkFunSuite -import org.apache.spark.streaming.util.{OpenHashMapBasedStateMap, StateMap} +import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { - test("OpenHashMapBasedStateMap - put, get, getall, remove") { + test("EmptyStateMap") { + val map = new EmptyStateMap[Int, Int] + intercept[scala.NotImplementedError] { + map.put(1, 1, 1) + } + assert(map.get(1) === None) + assert(map.getByTime(10000).isEmpty) + assert(map.getAll().isEmpty) + map.remove(1) // no exception + assert(map.copy().getAll().isEmpty) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { val map = new OpenHashMapBasedStateMap[Int, Int]() map.put(1, 100, 10) assert(map.get(1) === Some(100)) assert(map.get(2) === None) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10))) + map.put(2, 200, 20) + assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20))) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) map.remove(1) @@ -40,7 +61,7 @@ class StateMapSuite extends SparkFunSuite { assert(map.getAll().toSet === Set((2, 200, 20))) } - test("OpenHashMapBasedStateMap - put, get, getall, remove after copy") { + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") { val parentMap = new OpenHashMapBasedStateMap[Int, Int]() parentMap.put(1, 100, 1) parentMap.put(2, 200, 2) @@ -48,15 +69,19 @@ class StateMapSuite extends SparkFunSuite { // Create child map and make changes val map = parentMap.copy() - assert(map.getAll().toSet === Set((2, 200, 2))) assert(map.get(1) === None) assert(map.get(2) === Some(200)) + assert(map.getByTime(10).toSet === Set((2, 200, 2))) + assert(map.getByTime(2).toSet === Set.empty) + assert(map.getAll().toSet === Set((2, 200, 2))) // Add new items map.put(3, 300, 3) assert(map.get(3) === Some(300)) map.put(4, 400, 4) assert(map.get(4) === Some(400)) + assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3))) assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) assert(parentMap.getAll().toSet === Set((2, 200, 2))) @@ -95,74 +120,6 @@ class StateMapSuite extends SparkFunSuite { 
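Taken together, the assertions above pin down two contracts of the state map: getByTime(t) returns only entries updated strictly before t, and copy() is copy-on-write, so a child map starts from the parent's data but never writes back into it. A condensed, standalone sketch of those contracts, assuming test-style access to the private[streaming] OpenHashMapBasedStateMap added in this patch series:

    import org.apache.spark.streaming.util.{OpenHashMapBasedStateMap, StateMap}

    val parent = new OpenHashMapBasedStateMap[Int, Int]()
    parent.put(1, 100, 10L)                            // (key, state, updateTime)
    parent.put(2, 200, 20L)

    // getByTime(t) returns only the entries whose update time is strictly below t
    assert(parent.getByTime(11L).map(_._1).toSet == Set(1))
    assert(parent.getByTime(10L).isEmpty)

    // copy() is copy-on-write: the child sees the parent's data, but changes made
    // to the child are never visible in the parent
    val child: StateMap[Int, Int] = parent.copy()
    child.remove(1)
    child.put(3, 300, 30L)
    assert(parent.get(1) == Some(100))
    assert(child.get(1).isEmpty)
    assert(child.getAll().map(_._1).toSet == Set(2, 3))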
assert(childMap.get(2) === Some(20000)) // item map } - test("OpenHashMapBasedStateMap - all operation combo testing with copies ") { - val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value - val numMapCopies = 4 // to test all combos of operations across 4 copies - val numOpsPerCopy = numTypeMapOps - val numTotalOps = numOpsPerCopy * numMapCopies - val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops - - var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() - val refMap = new mutable.HashMap[Int, Int]() - - def assertMap(): Unit = { - assert(stateMap.getAll().map { x => (x._1, x._2) }.toSet === refMap.iterator.toSet) - for (keyId <- 0 until numKeys) { - assert(stateMap.get(keyId) === refMap.get(keyId)) - } - } - - /** - * Example: Operations combinations with 2 map copies - * - * _______________________________________________ - * | | Copy1 | Copy2 | - * | |-----------------|-----------------| - * | | Op1 Op2 | Op3 Op4 | - * |---------|-----------------|-----------------| - * | key 0 | put put | | put put | - * | key 1 | put put | | put rem | - * | key 2 | put put |c| rem put | - * | key 3 | put put |o| rem rem | - * | key 4 | put rem |p| put put | - * | key 5 | put rem |y| put rem | - * | key 6 | put rem | | rem put | - * | key 7 | put rem |t| rem rem | - * | key 8 | rem put |h| put put | - * | key 9 | rem put |e| put rem | - * | key 10 | rem put | | rem put | - * | key 11 | rem put |m| rem rem | - * | key 12 | rem rem |a| put put | - * | key 13 | rem rem |p| put rem | - * | key 14 | rem rem | | rem put | - * | key 15 | rem rem | | rem rem | - * |_________|_________________|_________________| - */ - - for(opId <- 0 until numTotalOps) { - for (keyId <- 0 until numKeys) { - // Find the operation type that needs to be done - // This is similar to finding the nth bit value of a binary number - // E.g. 
nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 - val opCode = (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps - opCode match { - case 0 => - val value = Random.nextInt() - stateMap.put(keyId, value, value * 2) - refMap.put(keyId, value) - case 1 => - stateMap.remove(keyId) - refMap.remove(keyId) - } - } - if (opId % numOpsPerCopy == 0) { - assertMap() - stateMap = stateMap.copy() - } - } - assertMap() - } - test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() map1.put(1, 100, 1) @@ -179,10 +136,9 @@ class StateMapSuite extends SparkFunSuite { // Do not test compaction assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - val map3_ = Utils.deserialize[StateMap[Int, Int]]( + val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assert(map3_.getAll().toSet === map3.getAll().toSet) - assert(map3.getAll().forall { case (key, state, _) => map3_.get(key) === Some(state)}) + assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -194,7 +150,7 @@ class StateMapSuite extends SparkFunSuite { // Make large delta chain with length more than deltaChainThreshold for(i <- 1 to targetDeltaLength) { - map.put(Random.nextInt(), Random.nextInt(), Random.nextLong()) + map.put(Random.nextInt(), Random.nextInt(), 1) map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] } assert(map.deltaChainLength > deltaChainThreshold) @@ -204,7 +160,155 @@ class StateMapSuite extends SparkFunSuite { Utils.serialize(map), Thread.currentThread().getContextClassLoader) assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assert(deser_map.getAll().toSet === map.getAll().toSet) - assert(map.getAll().forall { case (key, state, _) => deser_map.get(key) === Some(state)}) + assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") + } + + test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { + /* + * This tests the map using all permutations of sequences operations, across multiple map + * copies as well as between copies. It is to ensure complete coverage, though it is + * kind of hard to debug this. It is set up as follows. + * + * - For any key, there can be 2 types of update ops on a state map - put or remove + * + * - These operations are done on a test map in "sets". After each set, the map is "copied" + * to create a new map, and the next set of operations are done on the new one. This tests + * whether the map data persistes correctly across copies. + * + * - Within each set, there are a number of operations to test whether the map correctly + * updates and removes data without affecting the parent state map. + * + * - Overall this creates (numSets * numOpsPerSet) operations, each of which that can 2 types + * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence + * of operations, which we will test with different keys. + * + * Example: Operation combinations with numSets = 2, and numOpsPerSet = 2 give 4 operations, + * 2 ^ 4 = 16 possible permutations, tested using 16 keys. 
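The enumeration described above treats each key as a number written in base numTypeMapOps, with one digit per operation slot, so every key exercises a distinct put/remove sequence. A standalone sketch of that digit-extraction step, with small constants chosen only for illustration (the table of the resulting sequences follows below):

    val numTypeMapOps = 2   // 0 = put, 1 = remove
    val numTotalOps   = 4   // e.g. numSets = 2 sets of numOpsPerSet = 2 ops each

    // The opId-th "digit" of keyId in base numTypeMapOps picks the operation type
    def opCode(keyId: Int, opId: Int): Int =
      (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps

    // key 6 is binary 0110, so its operation sequence is put, rem, rem, put
    assert((0 until numTotalOps).map(opCode(6, _)) == Seq(0, 1, 1, 0))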
+ * _______________________________________________ + * | | Set1 | Set2 | + * | |-----------------|-----------------| + * | | Op1 Op2 |c| Op3 Op4 | + * |---------|----------------|o|----------------| + * | key 0 | put put |p| put put | + * | key 1 | put put |y| put rem | + * | key 2 | put put | | rem put | + * | key 3 | put put |t| rem rem | + * | key 4 | put rem |h| put put | + * | key 5 | put rem |e| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |s| rem rem | + * | key 8 | rem put |t| put put | + * | key 9 | rem put |a| put rem | + * | key 10 | rem put |t| rem put | + * | key 11 | rem put |e| rem rem | + * | key 12 | rem rem | | put put | + * | key 13 | rem rem |m| put rem | + * | key 14 | rem rem |a| rem put | + * | key 15 | rem rem |p| rem rem | + * |_________|________________|_|________________| + */ + + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numSets = 3 + val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set + val numTotalOps = numOpsPerSet * numSets + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + val refMap = new mutable.HashMap[Int, (Int, Long)]() + var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + var prevSetStateMap: StateMap[Int, Int] = null + + var time = 1L + + for (setId <- 0 until numSets) { + for(opInSetId <- 0 until numOpsPerSet) { + val opId = setId * numOpsPerSet + opInSetId + for (keyId <- 0 until numKeys) { + time += 1 + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = + (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, time) + refMap.put(keyId, (value, time)) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + + // Test whether the current state map after all key updates is correct + assertMap(stateMap, refMap, time, "State map does not match reference map") + + // Test whether the previous map before copy has not changed + if (prevSetStateMap != null && prevSetRefMap != null) { + assertMap(prevSetStateMap, prevSetRefMap, time, + "Parent state map somehow got modified, does not match corresponding reference map") + } + } + + // Copy the map and remember the previous maps for future tests + prevSetStateMap = stateMap + prevSetRefMap = refMap.toMap + stateMap = stateMap.copy() + + // Assert that the copied map has the same data + assertMap(stateMap, prevSetRefMap, time, + "State map does not match reference map after copying") + } + assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") + } + + // Assert whether all the data and operations on a state map matches that of a reference state map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: StateMap[Int, Int], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.getAll().map { _._1 }) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId)) + } + + // Assert that every time threshold returns the correct data + for 
(t <- 0L to (time + 1)) { + assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet) + } + } + } + + // Assert whether all the data and operations on a state map matches that of a reference map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: Map[Int, (Int, Long)], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === + refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.keys) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 }) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + val expectedRecords = + refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) } + assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet) + } + } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 07831cc297d6..bde2a93336bb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -146,11 +146,11 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq( Seq(), Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed Seq() ) @@ -194,16 +194,62 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testOperation(inputData, TrackStateSpec(trackStateFunc).numPartitions(1), outputData, stateData) } + test("state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, TrackStateSpec(trackStateFunc).timeout(Seconds(3)), 20) + + // b and c should be emitted once each, when they were marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + assert(collectedStateSnapshots.exists { _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) }) + + // Finally state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: TrackStateSpec[K, Int, S, T], expectedOutputs: Seq[Seq[T]], expectedStateSnapshots: Seq[Seq[(K, S)]] - ) { - + ): Unit = { require(expectedOutputs.size == expectedStateSnapshots.size) + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, trackStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state 
snapshots") + } + + private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + trackStateSpec: TrackStateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + // Setup the stream computation val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) @@ -218,13 +264,11 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef val batchCounter = new BatchCounter(ssc) ssc.start() - val numBatches = expectedOutputs.size val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) - assert(expectedOutputs, collectedOutputs, "outputs") - assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + (collectedOutputs, collectedStateSnapshots) } private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { From df927bada90b58e75616b55564adc37eb83292f4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 3 Nov 2015 01:19:43 -0800 Subject: [PATCH 16/26] Added license --- .../streaming/rdd/TrackStateRDDSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index c37c2751b2d3..2bcf8c516bb6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.streaming.rdd import scala.collection.mutable.ArrayBuffer From 6a75966cf863566008c6c03ad7fe7924e3a25db9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 3 Nov 2015 02:17:34 -0800 Subject: [PATCH 17/26] Scala style fixes --- .../src/main/scala/org/apache/spark/streaming/State.scala | 8 +++++--- .../scala/org/apache/spark/streaming/TrackStateSpec.scala | 2 +- .../spark/streaming/dstream/EmittedRecordsDStream.scala | 6 ++++-- .../spark/streaming/dstream/PairDStreamFunctions.scala | 4 +++- .../org/apache/spark/streaming/rdd/TrackStateRDD.scala | 4 ++-- .../scala/org/apache/spark/streaming/util/StateMap.scala | 6 ++++-- .../scala/org/apache/spark/streaming/StateMapSuite.scala | 2 +- .../org/apache/spark/streaming/TrackStateByKeySuite.scala | 2 -- .../apache/spark/streaming/rdd/TrackStateRDDSuite.scala | 6 +++--- 9 files changed, 23 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 49c68a352745..061eb57c8c79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -26,10 +26,10 @@ package org.apache.spark.streaming * }}} */ sealed abstract class State[S] { - + /** Whether the state already exists */ def exists(): Boolean - + /** * Get the state if it exists, otherwise wise it will throw an exception. * Check with `exists()` whether the state exists or not before calling `get()`. @@ -56,7 +56,9 @@ sealed abstract class State[S] { if (exists) this.get else default } - @inline final override def toString() = getOption.map { _.toString }.getOrElse("") + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("") + } } /** Internal implementation of the [[State]] interface */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala index f0edcf2b9bfe..4d639a8f85a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala @@ -108,4 +108,4 @@ case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala index 55ce8689f101..de357ea5664a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala @@ -33,7 +33,8 @@ abstract class EmittedRecordsDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: C } -private[streaming] class EmittedRecordsDStreamImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( +private[streaming] class EmittedRecordsDStreamImpl[ + K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( trackStateDStream: TrackStateDStream[K, V, S, T]) extends EmittedRecordsDStream[K, V, S, T](trackStateDStream.context) { @@ -46,7 +47,8 @@ private[streaming] class EmittedRecordsDStreamImpl[K: ClassTag, V: ClassTag, S: } def stateSnapshots(): DStream[(K, S)] = { - 
trackStateDStream.flatMap[(K, S)] { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + trackStateDStream.flatMap[(K, S)] { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index b2d307261661..cdc668a4ef8f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -349,7 +349,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } - def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): EmittedRecordsDStream[K, V, S, T] = { + /** TODO: Add scala docs */ + def trackStateByKey[S: ClassTag, T: ClassTag]( + spec: TrackStateSpec[K, V, S, T]): EmittedRecordsDStream[K, V, S, T] = { new EmittedRecordsDStreamImpl[K, V, S, T]( new TrackStateDStream[K, V, S, T]( self, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index db793e47e751..3183ebce14d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -30,7 +30,7 @@ import org.apache.spark._ private[streaming] case class TrackStateRDDRecord[K, S, T]( var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) { -/* + /* private def writeObject(outputStream: ObjectOutputStream): Unit = { outputStream.writeObject(stateMap) outputStream.writeInt(emittedRecords.size) @@ -165,7 +165,7 @@ private[streaming] object TrackStateRDD { partitioner: Partitioner, updateTime: Long): TrackStateRDD[K, V, S, T] = { - val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index cdd85452d30b..b6b629a9000b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -65,7 +65,9 @@ private[streaming] object StateMap { /** Specific implementation of SessionStore interface representing an empty map */ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { - override def put(key: K, session: S, updateTime: Long): Unit = ??? 
+ override def put(key: K, session: S, updateTime: Long): Unit = { + throw new NotImplementedError("put() should not be called on an EmptyStateMap") + } override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty override def getAll(): Iterator[(K, S, Long)] = Iterator.empty @@ -184,7 +186,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( override def toDebugString(): String = { val tabs = if (deltaChainLength > 0) { - (" " * (deltaChainLength - 1)) +"+--- " + (" " * (deltaChainLength - 1)) + "+--- " } else "" parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index bb4f8481a980..d50ecdee498d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -223,7 +223,7 @@ class StateMapSuite extends SparkFunSuite { var time = 1L for (setId <- 0 until numSets) { - for(opInSetId <- 0 until numOpsPerSet) { + for (opInSetId <- 0 until numOpsPerSet) { val opId = setId * numOpsPerSet + opInSetId for (keyId <- 0 until numKeys) { time += 1 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index bde2a93336bb..aa2fed71a215 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -182,11 +182,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { if (state.exists) { state.remove() - println(s"$key: state exists, removed state, and returning key") Some(key) } else { state.update(value.get) - println(s"$key: State does not exists, saving state, and not returning anything") None } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 2bcf8c516bb6..c309ff7cc8d8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -128,8 +128,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 Set(("k1", 1, updateTime), ("k2", 0, initTime))) - val rdd4 = testStateUpdates( - rdd3, Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) val rdd5 = testStateUpdates( @@ -186,4 +186,4 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { object TrackStateRDDSuite { private val touchedStateKeys = new ArrayBuffer[String]() -} \ No newline at end of file +} From 0f1b1bc9d71921b8a526bbf88994b6585861f0ec Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 3 Nov 2015 02:39:55 -0800 Subject: [PATCH 18/26] Added tests for initial state RDDs --- 
.../spark/streaming/StateMapSuite.scala | 4 +- .../streaming/TrackStateByKeySuite.scala | 41 ++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index d50ecdee498d..15b05d8ed259 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -182,8 +182,8 @@ class StateMapSuite extends SparkFunSuite { * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence * of operations, which we will test with different keys. * - * Example: Operation combinations with numSets = 2, and numOpsPerSet = 2 give 4 operations, - * 2 ^ 4 = 16 possible permutations, tested using 16 keys. + * Example: With numSets = 2, and numOpsPerSet = 2 give numTotalOps = 4. This means that + * 2 ^ 4 = 16 possible permutations needs to be tested using 16 keys. * _______________________________________________ * | | Set1 | Set2 | * | |-----------------|-----------------| diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index aa2fed71a215..ea9284484f5c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -141,6 +141,45 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testOperation(inputData, TrackStateSpec(trackStateFunc), outputData, stateData) } + test("initial states, with nothing emitted") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val trackStateSpec = TrackStateSpec(trackStateFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, trackStateSpec, outputData, stateData) + } + test("state removing") { val inputData = Seq( @@ -280,4 +319,4 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ) } } -} + From 62d9abdf76a386cc8e08c4094b4c5d91def7a432 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 6 Nov 2015 06:49:48 -0800 Subject: [PATCH 19/26] Updated docs, renamed clases --- .../streaming/StatefulNetworkWordCount.scala | 2 +- .../org/apache/spark/streaming/State.scala | 86 +++++++-- .../apache/spark/streaming/StateSpec.scala | 181 ++++++++++++++++++ .../spark/streaming/TrackStateSpec.scala | 111 ----------- .../dstream/EmittedRecordsDStream.scala | 23 ++- .../dstream/PairDStreamFunctions.scala | 47 ++++- .../spark/streaming/rdd/TrackStateRDD.scala | 4 + 
.../streaming/TrackStateByKeySuite.scala | 137 +++++++++---- .../streaming/rdd/TrackStateRDDSuite.scala | 9 +- 9 files changed, 427 insertions(+), 173 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 50f637f9762f..f0b6ead255e0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -69,7 +69,7 @@ object StatefulNetworkWordCount { } val stateDstream = wordDstream.trackStateByKey( - TrackStateSpec(trackStateFunc).initialState(initialRDD)) + StateSpec(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 061eb57c8c79..40fedf7a535d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -17,50 +17,108 @@ package org.apache.spark.streaming +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental + /** + * :: Experimental :: * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] and - * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. + * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Scala example of using `State`: * {{{ + * def trackStateFunc(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = { + * + * // Check if state exists + * if (state.exists) { + * + * val existingState = wrappedState.get // Get the existing state + * + * val shouldRemove = ... // Decide whether to remove the state + * + * if (shouldRemove) { + * + * wrappedState.remove() // Remove the state + * + * } else { + * + * val newState = ... + * wrappedState(newState) // Set the new state + * + * } + * } else { + * + * val initialState = ... + * state.update(initialState) // Set the initial state + * + * } + * } * * }}} + * + * Java example: + * {{{ + * TODO(@zsxwing) + * }}} */ +@Experimental sealed abstract class State[S] { /** Whether the state already exists */ def exists(): Boolean /** - * Get the state if it exists, otherwise wise it will throw an exception. + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. */ def get(): S /** - * Update the state with a new value. Note that you cannot update the state if the state is - * timing out (that is, `isTimingOut() return true`, or if the state has already been removed by - * `remove()`. + * Update the state with a new value. 
+ * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed */ def update(newState: S): Unit - /** Remove the state if it exists. */ + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + */ def remove(): Unit - /** Is the state going to be timed out by the system after this batch interval */ + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeou can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data + * for that timeout duration. + */ def isTimingOut(): Boolean - @inline final def getOption(): Option[S] = Option(get()) - - /** Get the state if it exists, otherwise return the default value */ - @inline final def getOrElse[S1 >: S](default: => S1): S1 = { - if (exists) this.get else default - } + /** + * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. + */ + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None @inline final override def toString(): String = { getOption.map { _.toString }.getOrElse("") } } +private[streaming] +object State { + implicit def toOption[S](state: State[S]): Option[S] = state.getOption() +} + /** Internal implementation of the [[State]] interface */ private[streaming] class StateImpl[S] extends State[S] { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala new file mode 100644 index 000000000000..fd7a2bef863a --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.{HashPartitioner, Partitioner} + + +/** + * :: Experimental :: + * Abstract class representing all the specifications of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
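A compact sketch of the existence-check pattern the State scaladoc above describes, written against the methods actually declared on the handle (exists, get, update, remove); the running-count logic and the removal condition are placeholders, not part of the patch.

    def trackStateFunc(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = {
      if (wrappedState.exists) {
        val existingState = wrappedState.get           // safe: exists was checked first
        val shouldRemove  = data.isEmpty               // placeholder removal condition
        if (shouldRemove) {
          wrappedState.remove()                        // drop this key's state
          None
        } else {
          val newState = existingState + data.getOrElse(0)
          wrappedState.update(newState)                // replace the tracked state
          Some(newState)
        }
      } else {
        val initialState = data.getOrElse(0)
        wrappedState.update(initialState)              // create the initial state
        Some(initialState)
      }
    }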
+ * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or + * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of + * this class. + * + * Example in Scala: + * {{{ + * val spec = StateSpec(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * }}} + * + * Example in Java: + * {{{ + * StateStateSpec[StateType, EmittedDataType] spec = + * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10); + * + * JavaDStream[EmittedDataType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +sealed abstract class StateSpec[K, V, S, T] extends Serializable { + + /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ + def initialState(rdd: RDD[(K, S)]): this.type + + /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type + + /** + * Set the number of partitions by which the state RDDs generated by `trackStateByKey` + * will be partitioned. Hash partitioning will be used on the + */ + def numPartitions(numPartitions: Int): this.type + + /** + * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be + * be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type + + /** + * Set the duration after which the state of an idle key will be removed. A key and its state is + * considered idle if it has not received any data for at least the given duration. The state + * tracking function will be called one final time on the idle states that are going to be + * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set + * to `true` in that call. + */ + def timeout(idleDuration: Duration): this.type +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] + * that is used for specifying the parameters of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
+ * + * Example in Scala: + * {{{ + * val spec = StateSpec(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * }}} + * + * Example in Java: + * {{{ + * StateStateSpec[StateType, EmittedDataType] spec = + * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10); + * + * JavaDStream[EmittedDataType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +object StateSpec { + /** + * + * @param trackingFunction + * @tparam KeyType Class of keys + * @tparam ValueType + * @tparam StateType + * @tparam EmittedType + * @return + */ + def apply[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + new StateSpecImpl[KeyType, ValueType, StateType, EmittedType](trackingFunction) + } + + def create[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + apply(trackingFunction) + } +} + + +/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */ +private[streaming] +case class StateSpecImpl[K, V, S: ClassTag, T: ClassTag]( + function: (K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + + def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala deleted file mode 100644 index 4d639a8f85a3..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/TrackStateSpec.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import scala.reflect.ClassTag - -import org.apache.spark.{HashPartitioner, Partitioner} -import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.rdd.RDD - - -/** - * Abstract class having all the specifications of DStream.trackStateByKey(). - * Use the `TrackStateSpec.create()` or `TrackStateSpec.create()` to create instances of this class. - * - * {{{ - * TrackStateSpec(trackingFunction) // in Scala - * TrackStateSpec.create(trackingFunction) // in Java - * }}} - */ -sealed abstract class TrackStateSpec[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag] - extends Serializable { - - def initialState(rdd: RDD[(K, S)]): this.type - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type - - def numPartitions(numPartitions: Int): this.type - def partitioner(partitioner: Partitioner): this.type - - def timeout(interval: Duration): this.type -} - - -/** Builder object for creating instances of TrackStateSpec */ -object TrackStateSpec { - - def apply[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { - new TrackStateSpecImpl[K, V, S, T](trackingFunction) - } - - def create[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackingFunction: (K, Option[V], State[S]) => Option[T]): TrackStateSpec[K, V, S, T] = { - apply(trackingFunction) - } -} - - -/** Internal implementation of [[TrackStateSpec]] interface */ -private[streaming] -case class TrackStateSpecImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - function: (K, Option[V], State[S]) => Option[T]) extends TrackStateSpec[K, V, S, T] { - - require(function != null) - - @volatile private var partitioner: Partitioner = null - @volatile private var initialStateRDD: RDD[(K, S)] = null - @volatile private var timeoutInterval: Duration = null - - - def initialState(rdd: RDD[(K, S)]): this.type = { - this.initialStateRDD = rdd - this - } - - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { - this.initialStateRDD = javaPairRDD.rdd - this - } - - - def numPartitions(numPartitions: Int): this.type = { - this.partitioner(new HashPartitioner(numPartitions)) - this - } - - def partitioner(partitioner: Partitioner): this.type = { - this.partitioner = partitioner - this - } - - def timeout(interval: Duration): this.type = { - this.timeoutInterval = interval - this - } - - // ================= Private Methods ================= - - private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function - - private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) - - private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) - - private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala index de357ea5664a..c1c07dc48e4b 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala @@ -20,23 +20,36 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} - -abstract class EmittedRecordsDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( +/** + * :: Experimental :: + * DStream representing the stream of records emitted after the `trackStateByKey` operation + * on a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] in Scala. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. + * + * @tparam K Class of the state key + * @tparam S Class of the state data + * @tparam T Class of the emitted records + */ +@Experimental +sealed abstract class EmittedRecordsDStream[K, S, T: ClassTag]( ssc: StreamingContext) extends DStream[T](ssc) { + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ def stateSnapshots(): DStream[(K, S)] } - +/** Internal implementation of the [[EmittedRecordsDStream]] */ private[streaming] class EmittedRecordsDStreamImpl[ K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( trackStateDStream: TrackStateDStream[K, V, S, T]) - extends EmittedRecordsDStream[K, V, S, T](trackStateDStream.context) { + extends EmittedRecordsDStream[K, S, T](trackStateDStream.context) { override def slideDuration: Duration = trackStateDStream.slideDuration @@ -64,7 +77,7 @@ private[streaming] class EmittedRecordsDStreamImpl[ * @tparam T Type of the eiitted records */ private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T]) + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, T]) extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { persist(StorageLevel.MEMORY_ONLY) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index cdc668a4ef8f..9c98215d370f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,9 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.streaming.StreamingContext.rddToFileName -import org.apache.spark.streaming.{Duration, Time, TrackStateSpec, TrackStateSpecImpl} +import org.apache.spark.streaming._ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} import org.apache.spark.{HashPartitioner, Partitioner} @@ -349,13 +350,45 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } - /** TODO: Add scala docs */ - def trackStateByKey[S: ClassTag, T: ClassTag]( - spec: TrackStateSpec[K, V, S, T]): EmittedRecordsDStream[K, V, S, T] = { - new EmittedRecordsDStreamImpl[K, V, S, T]( - new 
TrackStateDStream[K, V, S, T]( + /** + * :: Experimental :: + * Return a new DStream of data generated by combining the key-value data in `this` stream + * with a continuously updated per-key state. The user-provided state tracking function is + * applied on each keyed data item along with its corresponding state. The function can choose to + * update/remove the state and return a transformed data, which forms the + * [[org.apache.spark.streaming.dstream.EmittedRecordsDStream]]. + * + * The specifications of this transformation is made through the + * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there + * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. + * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details. + * + * Scala example of using `trackStateByKey`: + * {{{ + * def trackingFunction(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = { + * // Check state exists, accordingly update/remove state and return data to emit + * } + * + * val spec = StateSpec(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * + * }}} + * + * Java example of using `trackStateByKey`: + * {{{ + * TODO(@zsxwing) + * }}} + * + * @tparam StateType + */ + @Experimental + def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( + spec: StateSpec[K, V, StateType, StateType]): EmittedRecordsDStream[K, StateType, EmittedType] = { + new EmittedRecordsDStreamImpl[K, V, StateType, EmittedType]( + new TrackStateDStream[K, V, StateType, EmittedType]( self, - spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]] + spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] ) ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 3183ebce14d0..8de640b6408b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -156,6 +156,10 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: prevStateRDD = null partitionedDataRDD = null } + + def setFullScan(): Unit = { + doFullScan = true + } } private[streaming] object TrackStateRDD { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index ea9284484f5c..a7e6548d03d7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.streaming import java.io.File -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.util.{Utils, ManualClock} -import org.apache.spark.{SparkFunSuite, SparkConf, SparkContext} -import org.apache.spark.streaming.dstream.{EmittedRecordsDStream, DStream} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { @@ -36,7 +35,9 @@ class 
TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef private val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } checkpointDir = Utils.createTempDir("checkpoint") ssc = new StreamingContext(sc, batchDuration) @@ -44,7 +45,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } after { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } } override def beforeAll(): Unit = { @@ -53,7 +56,72 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } - test("basic operation") { + test("state - get, exists, update, remove, ") { + val state = new StateImpl[Int]() + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOrElse(-1) === expectedData.get) // test implicit Option conversion + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOrElse(-1) === -1) // test implicit Option conversion + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeRemoved = true) + + state.remove() + testState(None, shouldBeUpdated = true) + + state.wrapTiminoutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + + test("trackStateByKey - basic operations") { val inputData = Seq( Seq(), @@ -87,17 +155,17 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq(("a", 5), ("b", 3), ("c", 1)) ) - // state maintains running count, key string doubled and returned - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOrElse(0) - state.update(sum) - Some(key * 2) + // state maintains running count, key string doubled and returned + val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOrElse(0) + state.update(sum) + Some(key * 2) } - testOperation(inputData, TrackStateSpec(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec(trackStateFunc), outputData, stateData) } - test("states as emitted records") { + test("trackStateByKey - states as emitted records") { val inputData = Seq( Seq(), @@ -138,10 +206,10 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Some(output) } - testOperation(inputData, TrackStateSpec(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec(trackStateFunc), outputData, stateData) 
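      // testOperation (see the helpers at the end of this suite) runs the spec through a
      // streaming context and verifies, batch by batch, both the records emitted by the
      // tracking function and the expected per-key state after each batch.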
} - test("initial states, with nothing emitted") { + test("trackStateByKey - initial states, with nothing emitted") { val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) @@ -176,20 +244,20 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef None.asInstanceOf[Option[Int]] } - val trackStateSpec = TrackStateSpec(trackStateFunc).initialState(sc.makeRDD(initialState)) + val trackStateSpec = StateSpec(trackStateFunc).initialState(sc.makeRDD(initialState)) testOperation(inputData, trackStateSpec, outputData, stateData) } - test("state removing") { + test("trackStateByKey - state removing") { val inputData = Seq( Seq(), Seq("a"), - Seq("a", "b"), // a will be removed - Seq("a", "b", "c"), // b will be removed - Seq("a", "b", "c"), // a and c will be removed - Seq("a", "b"), // b will be removed - Seq("a"), // a will be removed + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed Seq() ) @@ -228,19 +296,19 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } } - testOperation(inputData, TrackStateSpec(trackStateFunc).numPartitions(1), outputData, stateData) + testOperation(inputData, StateSpec(trackStateFunc).numPartitions(1), outputData, stateData) } - test("state timing out") { + test("trackStateByKey - state timing out") { val inputData = Seq( Seq("a", "b", "c"), Seq("a", "b"), Seq("a"), - Seq(), // c will time out - Seq(), // b will time out - Seq("a") // a will not time out - ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { if (value.isDefined) { @@ -254,13 +322,15 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( - inputData, TrackStateSpec(trackStateFunc).timeout(Seconds(3)), 20) + inputData, StateSpec(trackStateFunc).timeout(Seconds(3)), 20) // b and c should be emitted once each, when they were marked as expired assert(collectedOutputs.flatten.sorted === Seq("b", "c")) // States for a, b, c should be defined at one point of time - assert(collectedStateSnapshots.exists { _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) }) + assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) // Finally state should be defined only for a assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) @@ -269,7 +339,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], - trackStateSpec: TrackStateSpec[K, Int, S, T], + trackStateSpec: StateSpec[K, Int, S, T], expectedOutputs: Seq[Seq[T]], expectedStateSnapshots: Seq[Seq[(K, S)]] ): Unit = { @@ -283,7 +353,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], - trackStateSpec: TrackStateSpec[K, Int, S, T], + trackStateSpec: StateSpec[K, Int, S, T], numBatches: Int ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { @@ -319,4 +389,5 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ) } } +} diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index c309ff7cc8d8..ff199156e935 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -46,7 +46,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { assert(rdd.partitioner === Some(partitioner)) } - test("updating state on existing TrackStateRDD") { + test("states generated by TrackStateRDD") { val initStates = Seq(("k1", 0), ("k2", 0)) val initTime = 123 val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet @@ -148,13 +148,17 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { ) } + + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], newDataRDD: RDD[(K, V)], trackStateFunc: (K, Option[V], State[S]) => Option[T], currentTime: Long, expectedStates: Set[(K, S, Int)], - expectedEmittedRecords: Set[T]): TrackStateRDD[K, V, S, T] = { + expectedEmittedRecords: Set[T], + doFullScan: Boolean = false + ): TrackStateRDD[K, V, S, T] = { val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { newDataRDD.partitionBy(testStateRDD.partitioner.get) @@ -164,6 +168,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val newStateRDD = new TrackStateRDD[K, V, S, T]( testStateRDD, newDataRDD, trackStateFunc, currentTime, None) + if (doFullScan) newStateRDD.setFullScan() // Persist to make sure that it gets computed only once and we can track precisely how many // state keys the computing touched From df3bb1b32f7468a56a95b74545a6f3c5e2781f22 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 6 Nov 2015 15:57:49 -0800 Subject: [PATCH 20/26] Small fixes --- .../src/main/scala/org/apache/spark/streaming/StateSpec.scala | 2 +- .../apache/spark/streaming/dstream/PairDStreamFunctions.scala | 2 +- .../scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index fd7a2bef863a..53feed709751 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -134,7 +134,7 @@ object StateSpec { /** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. 
*/ private[streaming] -case class StateSpecImpl[K, V, S: ClassTag, T: ClassTag]( +case class StateSpecImpl[K, V, S, T]( function: (K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { require(function != null) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 9c98215d370f..167ef399707e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -384,7 +384,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) */ @Experimental def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( - spec: StateSpec[K, V, StateType, StateType]): EmittedRecordsDStream[K, StateType, EmittedType] = { + spec: StateSpec[K, V, StateType, EmittedType]): EmittedRecordsDStream[K, StateType, EmittedType] = { new EmittedRecordsDStreamImpl[K, V, StateType, EmittedType]( new TrackStateDStream[K, V, StateType, EmittedType]( self, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 8de640b6408b..440f2857bbdb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -151,7 +151,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() prevStateRDD = null partitionedDataRDD = null From b28179f38f2dd2856c4afcfdee995c5c694b1929 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 6 Nov 2015 17:57:49 -0800 Subject: [PATCH 21/26] Minor changes from PR feedback --- .../streaming/dstream/EmittedRecordsDStream.scala | 1 + .../apache/spark/streaming/rdd/TrackStateRDD.scala | 14 ++++++-------- .../org/apache/spark/streaming/util/StateMap.scala | 12 +++++------- .../org/apache/spark/streaming/StateMapSuite.scala | 2 +- .../spark/streaming/rdd/TrackStateRDDSuite.scala | 5 ++--- 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala index c1c07dc48e4b..afa961f22e27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala @@ -59,6 +59,7 @@ private[streaming] class EmittedRecordsDStreamImpl[ trackStateDStream.getOrCompute(validTime).map { _.flatMap[T] { _.emittedRecords } } } + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. 
*/ def stateSnapshots(): DStream[(K, S)] = { trackStateDStream.flatMap[(K, S)] { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 440f2857bbdb..fb2715263d75 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -132,14 +132,12 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: emittedRecords ++= emittedRecord } - if (doFullScan) { - if (timeoutThresholdTime.isDefined) { - newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - wrappedState.wrapTiminoutState(state) - val emittedRecord = trackingFunction(key, None, wrappedState) - emittedRecords ++= emittedRecord - newStateMap.remove(key) - } + if (doFullScan && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTiminoutState(state) + val emittedRecord = trackingFunction(key, None, wrappedState) + emittedRecords ++= emittedRecord + newStateMap.remove(key) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index b6b629a9000b..ff202db6e9e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -71,13 +71,11 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty override def getAll(): Iterator[(K, S, Long)] = Iterator.empty - override def copy(): StateMap[K, S] = new EmptyStateMap[K, S] + override def copy(): StateMap[K, S] = this override def remove(key: K): Unit = { } override def toDebugString(): String = "" } - - /** Implementation of StateMap based on Spark's OpenHashMap */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], @@ -118,10 +116,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( !deltaMap.contains(key) } - val updatedStates = deltaMap.iterator.flatMap { case (key, stateInfo) => - if (! 
stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime) { - Some((key, stateInfo.data, stateInfo.updateTime)) - } else None + val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) => + !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime + }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) } oldStates ++ updatedStates } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 15b05d8ed259..28738c4414aa 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -35,7 +35,7 @@ class StateMapSuite extends SparkFunSuite { assert(map.getByTime(10000).isEmpty) assert(map.getAll().isEmpty) map.remove(1) // no exception - assert(map.copy().getAll().isEmpty) + assert(map.copy().eq(this)) } test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index ff199156e935..0a4d0070dfaf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -108,7 +108,6 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { newStateRDD } - // Test no-op, no state should change testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state testStateUpdates( @@ -148,8 +147,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { ) } - - + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], newDataRDD: RDD[(K, V)], @@ -177,6 +175,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { newStateRDD } + /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */ private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( trackStateRDD: TrackStateRDD[K, V, S, T], expectedStates: Set[(K, S, Int)], From a78130d502576d437e859261137bad67e15725ae Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 6 Nov 2015 19:41:35 -0800 Subject: [PATCH 22/26] Updated some docs --- .../org/apache/spark/streaming/State.scala | 10 ++--- .../apache/spark/streaming/StateSpec.scala | 39 ++++++++++++------ .../spark/streaming/util/StateMap.scala | 41 +++++++++++++++++-- 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 40fedf7a535d..e9164c8ae5ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -129,26 +129,26 @@ private[streaming] class StateImpl[S] extends State[S] { private var removed: Boolean = false // ========= Public API ========= - def exists(): Boolean = { + override def exists(): Boolean = { defined } - def get(): S = { + override def get(): S = { state } - def update(newState: S): Unit = { + override def update(newState: S): Unit = { require(!removed, "Cannot update the state after it has been removed") require(!timingOut, "Cannot update the 
state that is timing out") state = newState updated = true } - def isTimingOut(): Boolean = { + override def isTimingOut(): Boolean = { timingOut } - def remove(): Unit = { + override def remove(): Unit = { require(!timingOut, "Cannot remove the state that is timing out") require(!removed, "Cannot remove the state that has already been removed") removed = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 53feed709751..484f823ddc97 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -110,13 +110,16 @@ sealed abstract class StateSpec[K, V, S, T] extends Serializable { @Experimental object StateSpec { /** - * - * @param trackingFunction - * @tparam KeyType Class of keys - * @tparam ValueType - * @tparam StateType - * @tparam EmittedType - * @return + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data and + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data */ def apply[KeyType, ValueType, StateType, EmittedType]( trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] @@ -124,6 +127,18 @@ object StateSpec { new StateSpecImpl[KeyType, ValueType, StateType, EmittedType](trackingFunction) } + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
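   *
   * A minimal sketch, assuming a suitable `trackingFunction` and an RDD of initial
   * (key, state) pairs named `initialStateRDD` are in scope:
   * {{{
   *   val spec = StateSpec.create(trackingFunction)
   *     .initialState(initialStateRDD)
   *     .numPartitions(10)
   *     .timeout(Minutes(30))
   * }}}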
+ * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data and + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ def create[KeyType, ValueType, StateType, EmittedType]( trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { @@ -143,28 +158,28 @@ case class StateSpecImpl[K, V, S, T]( @volatile private var initialStateRDD: RDD[(K, S)] = null @volatile private var timeoutInterval: Duration = null - def initialState(rdd: RDD[(K, S)]): this.type = { + override def initialState(rdd: RDD[(K, S)]): this.type = { this.initialStateRDD = rdd this } - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { this.initialStateRDD = javaPairRDD.rdd this } - def numPartitions(numPartitions: Int): this.type = { + override def numPartitions(numPartitions: Int): this.type = { this.partitioner(new HashPartitioner(numPartitions)) this } - def partitioner(partitioner: Partitioner): this.type = { + override def partitioner(partitioner: Partitioner): this.type = { this.partitioner = partitioner this } - def timeout(interval: Duration): this.type = { + override def timeout(interval: Duration): this.type = { this.timeoutInterval = interval this } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index ff202db6e9e6..337070821eff 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -166,15 +166,21 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) } + /** Whether the delta chain lenght is long enough that it should be compacted */ def shouldCompact: Boolean = { deltaChainLength >= deltaChainThreshold } + /** Length of the delta chains of this map */ def deltaChainLength: Int = parentStateMap match { case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 case _ => 0 } + /** + * Approximate number of keys in the map. This is an overestimation that is mainly used to + * reserve capacity in a new map at delta compaction time. + */ def approxSize: Int = deltaMap.size + { parentStateMap match { case s: OpenHashMapBasedStateMap[_, _] => s.approxSize @@ -182,6 +188,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } } + /** Get all the data of this map as string formatted as a tree based on the delta depth */ override def toDebugString(): String = { val tabs = if (deltaChainLength > 0) { (" " * (deltaChainLength - 1)) + "+--- " @@ -193,11 +200,16 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" } - private def writeObject(outputStream: ObjectOutputStream): Unit = { + /** + * Serialize the map data. Besides serialization, this method actually compact the deltas + * (if needed) in a single pass over all the data in the map. + */ + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. 
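    // (Overall layout written below: the default fields, then the size and entries of this
    // map's delta, then a size hint followed by the active records of the parent map, which
    // are copied into a new compacted parent map if the delta chain has grown too long, and
    // finally a LimitMarker carrying the number of parent records written.)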
outputStream.defaultWriteObject() - // Write the deltaMap + // Write the data in the delta of this state map outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator var deltaMapCount = 0 @@ -209,7 +221,8 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } assert(deltaMapCount == deltaMap.size) - // Write the parentStateMap while consolidating + // Write the data in the parent state map while copying the data into a new parent map for + // compaction (if needed) val doCompaction = shouldCompact val newParentSessionStore = if (doCompaction) { val initCapacity = if (approxSize > 0) approxSize else 64 @@ -220,6 +233,8 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( var parentSessionCount = 0 + // First write the approximate size of the data to be written, so that readObject can + // allocate appropriately sized OpenHashMap. outputStream.writeInt(approxSize) while(iterOfActiveSessions.hasNext) { @@ -235,6 +250,8 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( key, StateInfo(state, updateTime, deleted = false)) } } + + // Write the final limit marking object with the correct count of records written. val limiterObj = new LimitMarker(parentSessionCount) outputStream.writeObject(limiterObj) if (doCompaction) { @@ -242,9 +259,13 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } } + /** Deserialize the map data. */ private def readObject(inputStream: ObjectInputStream): Unit = { + + // Read the non-transient fields, especially class tags, etc. inputStream.defaultReadObject() + // Read the data of the delta val deltaMapSize = inputStream.readInt() deltaMap = new OpenHashMap[K, StateInfo[S]]() var deltaMapCount = 0 @@ -255,10 +276,15 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaMapCount += 1 } + + // Read the data of the parent map. Keep reading records, until the limiter is reached + // First read the approximate number of records to expect and allocate properly size + // OpenHashMap val parentSessionStoreSizeHint = inputStream.readInt() val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + // Read the records until the limit marking object has been reached var parentSessionLoopDone = false while(!parentSessionLoopDone) { val obj = inputStream.readObject() @@ -278,8 +304,13 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } } +/** + * Companion object of [[OpenHashMapBasedStateMap]] having associated helper + * classes and methods + */ private[streaming] object OpenHashMapBasedStateMap { + /** Internal class to represent the state information */ case class StateInfo[S]( var data: S = null.asInstanceOf[S], var updateTime: Long = -1, @@ -296,6 +327,10 @@ private[streaming] object OpenHashMapBasedStateMap { } } + /** + * Internal class to represent a marker the demarkate the the end of all state data in the + * serialized bytes. 
+ */ class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 From fb5a296ac2c82eec16f2449267f94e98a46c54be Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Nov 2015 06:38:12 -0800 Subject: [PATCH 23/26] Addressed PR comments --- .../org/apache/spark/streaming/State.scala | 44 +++--- .../apache/spark/streaming/StateSpec.scala | 58 +++++--- .../dstream/EmittedRecordsDStream.scala | 115 --------------- .../dstream/PairDStreamFunctions.scala | 33 ++--- .../streaming/dstream/TrackStateDStream.scala | 134 ++++++++++++++++++ .../spark/streaming/rdd/TrackStateRDD.scala | 68 +++++---- .../spark/streaming/util/StateMap.scala | 4 +- .../spark/streaming/StateMapSuite.scala | 2 +- .../streaming/TrackStateByKeySuite.scala | 96 ++++++++++--- .../streaming/rdd/TrackStateRDDSuite.scala | 12 +- 10 files changed, 318 insertions(+), 248 deletions(-) delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index e9164c8ae5ad..7dd1b72f8049 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -29,31 +29,23 @@ import org.apache.spark.annotation.Experimental * * Scala example of using `State`: * {{{ - * def trackStateFunc(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = { - * + * // A tracking function that maintains an integer state and return a String + * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { * // Check if state exists * if (state.exists) { - * - * val existingState = wrappedState.get // Get the existing state - * - * val shouldRemove = ... // Decide whether to remove the state - * + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state * if (shouldRemove) { - * - * wrappedState.remove() // Remove the state - * + * state.remove() // Remove the state * } else { - * * val newState = ... - * wrappedState(newState) // Set the new state - * + * state.update(newState) // Set the new state * } * } else { - * * val initialState = ... * state.update(initialState) // Set the initial state - * * } + * ... // return something * } * * }}} @@ -98,7 +90,7 @@ sealed abstract class State[S] { /** * Whether the state is timing out and going to be removed by the system after the current batch. - * This timeou can occur if timeout duration has been specified in the + * This timeout can occur if timeout duration has been specified in the * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data * for that timeout duration. 
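   *
   * A tracking function can use this, for example, to emit one final record for a key that
   * is about to expire; note that `update()` and `remove()` cannot be called on a state that
   * is timing out and will throw an `IllegalArgumentException`.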
*/ @@ -114,16 +106,11 @@ sealed abstract class State[S] { } } -private[streaming] -object State { - implicit def toOption[S](state: State[S]): Option[S] = state.getOption() -} - /** Internal implementation of the [[State]] interface */ private[streaming] class StateImpl[S] extends State[S] { private var state: S = null.asInstanceOf[S] - private var defined: Boolean = true + private var defined: Boolean = false private var timingOut: Boolean = false private var updated: Boolean = false private var removed: Boolean = false @@ -134,13 +121,18 @@ private[streaming] class StateImpl[S] extends State[S] { } override def get(): S = { - state + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } } override def update(newState: S): Unit = { require(!removed, "Cannot update the state after it has been removed") require(!timingOut, "Cannot update the state that is timing out") state = newState + defined = true updated = true } @@ -151,6 +143,8 @@ private[streaming] class StateImpl[S] extends State[S] { override def remove(): Unit = { require(!timingOut, "Cannot remove the state that is timing out") require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false removed = true } @@ -167,7 +161,7 @@ private[streaming] class StateImpl[S] extends State[S] { } /** - * Internal method to update the state data and reset internal flags in `this`. + * Update the internal data and flags in `this` to the given state option. * This method allows `this` object to be reused across many state records. */ def wrap(optionalState: Option[S]): Unit = { @@ -186,7 +180,7 @@ private[streaming] class StateImpl[S] extends State[S] { } /** - * Internal method to update the state data and reset internal flags in `this`. + * Update the internal data and flags in `this` to the given state that is going to be timed out. * This method allows `this` object to be reused across many state records. */ def wrapTiminoutState(newState: S): Unit = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 484f823ddc97..0896f57c12bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.rdd.RDD +import org.apache.spark.util.ClosureCleaner import org.apache.spark.{HashPartitioner, Partitioner} @@ -37,28 +38,33 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * Example in Scala: * {{{ - * val spec = StateSpec(trackingFunction).numPartitions(10) + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... 
+ * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) * * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) * }}} * * Example in Java: * {{{ - * StateStateSpec[StateType, EmittedDataType] spec = - * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10); + * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); * * JavaDStream[EmittedDataType] emittedRecordDStream = * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); * }}} */ @Experimental -sealed abstract class StateSpec[K, V, S, T] extends Serializable { +sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ - def initialState(rdd: RDD[(K, S)]): this.type + def initialState(rdd: RDD[(KeyType, StateType)]): this.type /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type /** * Set the number of partitions by which the state RDDs generated by `trackStateByKey` @@ -93,15 +99,20 @@ sealed abstract class StateSpec[K, V, S, T] extends Serializable { * * Example in Scala: * {{{ - * val spec = StateSpec(trackingFunction).numPartitions(10) + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... + * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) * * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) * }}} * * Example in Java: * {{{ - * StateStateSpec[StateType, EmittedDataType] spec = - * StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10); + * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); * * JavaDStream[EmittedDataType] emittedRecordDStream = * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); @@ -115,16 +126,17 @@ object StateSpec { * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
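   *
   * An illustrative sketch with hypothetical names (note that this form of the tracking
   * function also receives the current batch time):
   * {{{
   *   def trackingFunction(batchTime: Time, key: String, value: Option[Int],
   *       state: State[Long]): Option[(String, Long)] = {
   *     val newCount = state.getOption.getOrElse(0L) + value.getOrElse(0).toLong
   *     state.update(newCount)
   *     Some((key, newCount))
   *   }
   *
   *   val spec = StateSpec.function(trackingFunction _).timeout(Minutes(30))
   * }}}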
* @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data and + * and generate the emitted data * @tparam KeyType Class of the keys * @tparam ValueType Class of the values * @tparam StateType Class of the states data * @tparam EmittedType Class of the emitted data */ - def apply[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] + def function[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - new StateSpecImpl[KeyType, ValueType, StateType, EmittedType](trackingFunction) + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + new StateSpecImpl(trackingFunction) } /** @@ -133,16 +145,20 @@ object StateSpec { * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data and - * @tparam KeyType Class of the keys + * and generate the emitted data * @tparam ValueType Class of the values * @tparam StateType Class of the states data * @tparam EmittedType Class of the emitted data */ - def create[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] - ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - apply(trackingFunction) + def function[ValueType, StateType, EmittedType]( + trackingFunction: (Option[ValueType], State[StateType]) => EmittedType + ): StateSpec[Any, ValueType, StateType, EmittedType] = { + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + val wrappedFunction = + (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { + Some(trackingFunction(value, state)) + } + new StateSpecImpl[Any, ValueType, StateType, EmittedType](wrappedFunction) } } @@ -150,7 +166,7 @@ object StateSpec { /** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */ private[streaming] case class StateSpecImpl[K, V, S, T]( - function: (K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { require(function != null) @@ -186,7 +202,7 @@ case class StateSpecImpl[K, V, S, T]( // ================= Private Methods ================= - private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function + private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala deleted file mode 100644 index afa961f22e27..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.dstream - -import scala.reflect.ClassTag - -import org.apache.spark._ -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.{EmptyRDD, RDD} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming._ -import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} - -/** - * :: Experimental :: - * DStream representing the stream of records emitted after the `trackStateByKey` operation - * on a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] in Scala. - * Additionally, it also gives access to the stream of state snapshots, that is, the state data of - * all keys after a batch has updated them. - * - * @tparam K Class of the state key - * @tparam S Class of the state data - * @tparam T Class of the emitted records - */ -@Experimental -sealed abstract class EmittedRecordsDStream[K, S, T: ClassTag]( - ssc: StreamingContext) extends DStream[T](ssc) { - - /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ - def stateSnapshots(): DStream[(K, S)] -} - -/** Internal implementation of the [[EmittedRecordsDStream]] */ -private[streaming] class EmittedRecordsDStreamImpl[ - K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackStateDStream: TrackStateDStream[K, V, S, T]) - extends EmittedRecordsDStream[K, S, T](trackStateDStream.context) { - - override def slideDuration: Duration = trackStateDStream.slideDuration - - override def dependencies: List[DStream[_]] = List(trackStateDStream) - - override def compute(validTime: Time): Option[RDD[T]] = { - trackStateDStream.getOrCompute(validTime).map { _.flatMap[T] { _.emittedRecords } } - } - - /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ - def stateSnapshots(): DStream[(K, S)] = { - trackStateDStream.flatMap[(K, S)] { - _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } - } -} - -/** - * A DStream that allows per-key state to be maintains, and arbitrary records to be generated - * based on updates to the state. 
- * - * @param parent Parent (key, value) stream that is the source - * @param spec Specifications of the trackStateByKey operation - * @tparam K Key type - * @tparam V Value type - * @tparam S Type of the state maintained - * @tparam T Type of the eiitted records - */ -private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, T]) - extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) { - - persist(StorageLevel.MEMORY_ONLY) - - private val partitioner = spec.getPartitioner().getOrElse( - new HashPartitioner(ssc.sc.defaultParallelism)) - - private val trackingFunction = spec.getFunction() - - override def slideDuration: Duration = parent.slideDuration - - override def dependencies: List[DStream[_]] = List(parent) - - override val mustCheckpoint = true - - /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = { - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, T]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, - validTime.milliseconds - ) - } - val newDataRDD = parent.getOrCompute(validTime).get - val partitionedDataRDD = newDataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - - Some(new TrackStateRDD(prevStateRDD, partitionedDataRDD, - trackingFunction, validTime.milliseconds, timeoutThresholdTime)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 167ef399707e..9bdffa0c2d7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -356,44 +356,37 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * with a continuously updated per-key state. The user-provided state tracking function is * applied on each keyed data item along with its corresponding state. The function can choose to * update/remove the state and return a transformed data, which forms the - * [[org.apache.spark.streaming.dstream.EmittedRecordsDStream]]. + * [[org.apache.spark.streaming.dstream.TrackStateDStream]]. * * The specifications of this transformation is made through the * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details. 
* - * Scala example of using `trackStateByKey`: + * Example of using `trackStateByKey`: * {{{ - * def trackingFunction(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = { - * // Check state exists, accordingly update/remove state and return data to emit + * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = { + * // Check if state exists, accordingly update/remove state and return transformed data * } * - * val spec = StateSpec(trackingFunction).numPartitions(10) + * val spec = StateSpec.function(trackingFunction).numPartitions(10) * - * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) - * - * }}} - * - * Java example of using `trackStateByKey`: - * {{{ - * TODO(@zsxwing) + * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec) * }}} * - * @tparam StateType + * @param spec Specification of this transformation + * @tparam StateType Class type of the state + * @tparam EmittedType Class type of the tranformed data return by the tracking function */ @Experimental def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( - spec: StateSpec[K, V, StateType, EmittedType]): EmittedRecordsDStream[K, StateType, EmittedType] = { - new EmittedRecordsDStreamImpl[K, V, StateType, EmittedType]( - new TrackStateDStream[K, V, StateType, EmittedType]( - self, - spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] - ) + spec: StateSpec[K, V, StateType, EmittedType]): TrackStateDStream[K, StateType, EmittedType] = { + new TrackStateDStreamImpl[K, V, StateType, EmittedType]( + self, + spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] ) } - /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala new file mode 100644 index 000000000000..701ccab3562a --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} + +/** + * :: Experimental :: + * DStream representing the stream of records emitted by the tracking function in the + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. + * + * @tparam KeyType Class of the state key + * @tparam StateType Class of the state data + * @tparam EmittedType Class of the emitted records + */ +@Experimental +sealed abstract class TrackStateDStream[KeyType, StateType, EmittedType: ClassTag]( + ssc: StreamingContext) extends DStream[EmittedType](ssc) { + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] +} + +/** Internal implementation of the [[TrackStateDStream]] */ +private[streaming] class TrackStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( + dataStream: DStream[(KeyType, ValueType)], + spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) + extends TrackStateDStream[KeyType, StateType, EmittedType](dataStream.context) { + + private val internalStream = + new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) + + override def slideDuration: Duration = internalStream.slideDuration + + override def dependencies: List[DStream[_]] = List(internalStream) + + override def compute(validTime: Time): Option[RDD[EmittedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } } + } + + /** + * Forward the checkpoint interval to the internal DStream that computes the state maps. This + * to make sure that this DStream does not get checkpointed, only the internal stream. + */ + override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = { + internalStream.checkpoint(checkpointInterval) + this + } + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] = { + internalStream.flatMap { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } +} + +/** + * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * based on updates to the state. This is the main DStream that implements the `trackStateByKey` + * operation on DStreams. 
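 * Each RDD it generates holds one [[org.apache.spark.streaming.rdd.TrackStateRDDRecord]] per
 * partition, containing the updated state map and the records emitted for that batch, which
 * [[TrackStateDStreamImpl]] then flattens into the emitted-record stream and the state
 * snapshots.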
+ * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the trackStateByKey operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam E Type of the emitted data + */ +private[streaming] +class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) + extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + /** Enable automatic checkpointing */ + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { + // Get the previous state or create a new empty state RDD + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, validTime + ) + } + + // Compute the new state RDD with previous state RDD and partitioned data RDD + parent.getOrCompute(validTime).map { dataRDD => + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index fb2715263d75..ed7cea26d060 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -23,36 +23,22 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.rdd.{MapPartitionsRDD, RDD} -import org.apache.spark.streaming.{StateImpl, State} +import org.apache.spark.streaming.{Time, StateImpl, State} import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils import org.apache.spark._ +/** + * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the tracking function of `trackStateByKey`. 
+ */ private[streaming] case class TrackStateRDDRecord[K, S, T]( - var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) { - /* - private def writeObject(outputStream: ObjectOutputStream): Unit = { - outputStream.writeObject(stateMap) - outputStream.writeInt(emittedRecords.size) - val iterator = emittedRecords.iterator - while(iterator.hasNext) { - outputStream.writeObject(iterator.next) - } - } - - private def readObject(inputStream: ObjectInputStream): Unit = { - stateMap = inputStream.readObject().asInstanceOf[StateMap[K, S]] - val numEmittedRecords = inputStream.readInt() - val array = new Array[T](numEmittedRecords) - var i = 0 - while(i < numEmittedRecords) { - array(i) = inputStream.readObject().asInstanceOf[T] - } - emittedRecords = array.toSeq - }*/ -} - + var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) +/** + * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state + * RDD, and a partitioned keyed-data RDD + */ private[streaming] class TrackStateRDDPartition( idx: Int, @transient private var prevStateRDD: RDD[_], @@ -75,14 +61,21 @@ private[streaming] class TrackStateRDDPartition( /** - * RDD storing the keyed-state of trackStateByKey and corresponding emitted records. - * Each partition of this RDD has a single record that contains a StateMap storing + * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records. + * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking + * function of `trackStateByKey`. + * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created + * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps + * in the `prevStateRDD` to create `this` RDD + * @param trackingFunction The function that will be used to update state and return new data + * @param batchTime The time of the batch to which this RDD belongs to. Use to update */ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], private var partitionedDataRDD: RDD[(K, V)], - trackingFunction: (K, Option[V], State[S]) => Option[T], - currentTime: Long, timeoutThresholdTime: Option[Long] + trackingFunction: (Time, K, Option[V], State[S]) => Option[T], + batchTime: Time, timeoutThresholdTime: Option[Long] ) extends RDD[TrackStateRDDRecord[K, S, T]]( partitionedDataRDD.sparkContext, List( @@ -111,6 +104,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: val dataIterator = partitionedDataRDD.iterator( stateRDDPartition.partitionedDataRDDPartition, context) + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one val newStateMap = if (prevStateRDDIterator.hasNext) { prevStateRDDIterator.next().stateMap.copy() } else { @@ -118,24 +112,28 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: } val emittedRecords = new ArrayBuffer[T] - val wrappedState = new StateImpl[S]() + // Call the tracking function on each record in the data RDD partition, and accordingly + // update the states touched, and the data returned by the tracking function. 
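    // (The single wrappedState instance is reused across records: it is re-wrapped around the
    // current key's state before each call, and its isRemoved/isUpdated flags decide whether
    // the key is removed from, or written back into, newStateMap.)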
dataIterator.foreach { case (key, value) => wrappedState.wrap(newStateMap.get(key)) - val emittedRecord = trackingFunction(key, Some(value), wrappedState) + val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState) if (wrappedState.isRemoved) { newStateMap.remove(key) } else if (wrappedState.isUpdated) { - newStateMap.put(key, wrappedState.get(), currentTime) + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) } emittedRecords ++= emittedRecord } + // If the RDD is expected to be doing a full scan of all the data in the StateMap, + // then use this opportunity to filter out those keys that have timed out. + // For each of them call the tracking function. if (doFullScan && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => wrappedState.wrapTiminoutState(state) - val emittedRecord = trackingFunction(key, None, wrappedState) + val emittedRecord = trackingFunction(batchTime, key, None, wrappedState) emittedRecords ++= emittedRecord newStateMap.remove(key) } @@ -165,17 +163,17 @@ private[streaming] object TrackStateRDD { def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Long): TrackStateRDD[K, V, S, T] = { + updateTime: Time): TrackStateRDD[K, V, S, T] = { val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) - iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) } + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) - val noOpFunc = (key: K, value: Option[V], state: State[S]) => None + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 337070821eff..ed622ef7bf70 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -63,7 +63,7 @@ private[streaming] object StateMap { } } -/** Specific implementation of SessionStore interface representing an empty map */ +/** Implementation of StateMap interface representing an empty map */ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { throw new NotImplementedError("put() should not be called on an EmptyStateMap") @@ -76,7 +76,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa override def toDebugString(): String = "" } -/** Implementation of StateMap based on Spark's OpenHashMap */ +/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], initialCapacity: Int = 64, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 
28738c4414aa..48d3b41b66cb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -35,7 +35,7 @@ class StateMapSuite extends SparkFunSuite { assert(map.getByTime(10000).isEmpty) assert(map.getAll().isEmpty) map.remove(1) // no exception - assert(map.copy().eq(this)) + assert(map.copy().eq(map)) } test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index a7e6548d03d7..97d75a6c5d0d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -57,7 +57,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } test("state - get, exists, update, remove, ") { - val state = new StateImpl[Int]() + var state: StateImpl[Int] = null def testState( expectedData: Option[Int], @@ -69,14 +69,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(state.exists) assert(state.get() === expectedData.get) assert(state.getOption() === expectedData) - assert(state.getOrElse(-1) === expectedData.get) // test implicit Option conversion + assert(state.getOption.getOrElse(-1) === expectedData.get) // test implicit Option conversion } else { assert(!state.exists) intercept[NoSuchElementException] { state.get() } assert(state.getOption() === None) - assert(state.getOrElse(-1) === -1) // test implicit Option conversion + assert(state.getOption.getOrElse(-1) === -1) // test implicit Option conversion } assert(state.isTimingOut() === shouldBeTimingOut) @@ -102,6 +102,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } } + state = new StateImpl[Int]() testState(None) state.wrap(None) @@ -111,17 +112,64 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testState(Some(1)) state.update(2) - testState(Some(2), shouldBeRemoved = true) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) state.remove() - testState(None, shouldBeUpdated = true) + testState(None, shouldBeRemoved = true) state.wrapTiminoutState(3) testState(Some(3), shouldBeTimingOut = true) } + test("trackStateByKey - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val trackStateFunc = (value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + sum + } + + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } - test("trackStateByKey - basic operations") { + test("trackStateByKey - basic operations with advanced API") { val inputData = Seq( Seq(), @@ -156,13 +204,13 @@ class TrackStateByKeySuite extends 
SparkFunSuite with BeforeAndAfterAll with Bef ) // state maintains running count, key string doubled and returned - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOrElse(0) + val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) state.update(sum) Some(key * 2) } - testOperation(inputData, StateSpec(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) } test("trackStateByKey - states as emitted records") { @@ -199,14 +247,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq(("a", 5), ("b", 3), ("c", 1)) ) - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOrElse(0) + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) val output = (key, sum) state.update(sum) Some(output) } - testOperation(inputData, StateSpec(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) } test("trackStateByKey - initial states, with nothing emitted") { @@ -237,14 +285,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) ) - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOrElse(0) + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) val output = (key, sum) state.update(sum) None.asInstanceOf[Option[Int]] } - val trackStateSpec = StateSpec(trackStateFunc).initialState(sc.makeRDD(initialState)) + val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState)) testOperation(inputData, trackStateSpec, outputData, stateData) } @@ -286,7 +334,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq() ) - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { if (state.exists) { state.remove() Some(key) @@ -296,7 +344,8 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } } - testOperation(inputData, StateSpec(trackStateFunc).numPartitions(1), outputData, stateData) + testOperation( + inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData) } test("trackStateByKey - state timing out") { @@ -310,7 +359,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef Seq("a") // a will not time out ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active - val trackStateFunc = (key: String, value: Option[Int], state: State[Int]) => { + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { if (value.isDefined) { state.update(1) } @@ -322,7 +371,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( - inputData, StateSpec(trackStateFunc).timeout(Seconds(3)), 20) + inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20) // b and c 
should be emitted once each, when they were marked as expired assert(collectedOutputs.flatten.sorted === Seq("b", "c")) @@ -379,13 +428,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef } private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") assert(expected.size === collected.size, - s"number of collected $typ (${collected.size}) different from expected (${expected.size})") + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) expected.zip(collected).foreach { case (c, e) => assert(c.toSet === e.toSet, - s"collected $typ is different from expected" + - "\nExpected:\n" + expected.mkString("\n") + - "\nCollected:\n" + collected.mkString("\n") + s"collected $typ is different from expected $debugString" ) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0a4d0070dfaf..fc5f26607ef9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.State +import org.apache.spark.streaming.{Time, State} import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -39,7 +39,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int]( - sc.parallelize(data), partitioner, 123) + sc.parallelize(data), partitioner, Time(123)) assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) assert(rdd.partitions.size === partitioner.numPartitions) @@ -52,7 +52,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet val partitioner = new HashPartitioner(2) val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( - sc.parallelize(initStates), partitioner, initTime).persist() + sc.parallelize(initStates), partitioner, Time(initTime)).persist() assertRDD(initStateRDD, initStateWthTime, Set.empty) val updateTime = 345 @@ -73,7 +73,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { // To track which keys are being touched TrackStateRDDSuite.touchedStateKeys.clear() - val trackingFunc = (key: String, data: Option[Int], state: State[Int]) => { + val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => { // Track the key that has been touched TrackStateRDDSuite.touchedStateKeys += key @@ -151,7 +151,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], newDataRDD: RDD[(K, V)], - trackStateFunc: (K, Option[V], State[S]) => Option[T], + trackStateFunc: (Time, K, Option[V], State[S]) => Option[T], currentTime: Long, expectedStates: Set[(K, S, Int)], expectedEmittedRecords: Set[T], @@ -165,7 
+165,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } val newStateRDD = new TrackStateRDD[K, V, S, T]( - testStateRDD, newDataRDD, trackStateFunc, currentTime, None) + testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None) if (doFullScan) newStateRDD.setFullScan() // Persist to make sure that it gets computed only once and we can track precisely how many From f1a669653811d6558b692fc521a0a7f29972ba33 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Nov 2015 12:02:20 -0800 Subject: [PATCH 24/26] Style fix --- .../main/scala/org/apache/spark/streaming/StateSpec.scala | 6 +++--- .../org/apache/spark/streaming/TrackStateByKeySuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 0896f57c12bc..728f304730c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -60,15 +60,15 @@ import org.apache.spark.{HashPartitioner, Partitioner} @Experimental sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { - /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ def initialState(rdd: RDD[(KeyType, StateType)]): this.type - /** Set the RDD containing the initial states that will be used by `trackStateByKey`*/ + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type /** * Set the number of partitions by which the state RDDs generated by `trackStateByKey` - * will be partitioned. Hash partitioning will be used on the + * will be partitioned. Hash partitioning will be used. 
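+   * For example (a minimal sketch, assuming a key-value DStream `wordDstream` and a tracking
+   * function `trackStateFunc` like the ones used in the tests of this patch):
+   * {{{
+   *   val spec = StateSpec.function(trackStateFunc).numPartitions(10)
+   *   val stateDstream = wordDstream.trackStateByKey(spec)
+   * }}}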
*/ def numPartitions(numPartitions: Int): this.type diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 97d75a6c5d0d..15a85fc5cd21 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -69,14 +69,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(state.exists) assert(state.get() === expectedData.get) assert(state.getOption() === expectedData) - assert(state.getOption.getOrElse(-1) === expectedData.get) // test implicit Option conversion + assert(state.getOption.getOrElse(-1) === expectedData.get) } else { assert(!state.exists) intercept[NoSuchElementException] { state.get() } assert(state.getOption() === None) - assert(state.getOption.getOrElse(-1) === -1) // test implicit Option conversion + assert(state.getOption.getOrElse(-1) === -1) } assert(state.isTimingOut() === shouldBeTimingOut) From 77c9a66e911adf74014bd9b16fb26153a445d372 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Nov 2015 13:22:54 -0800 Subject: [PATCH 25/26] fix build --- .../examples/streaming/StatefulNetworkWordCount.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index f0b6ead255e0..be2ae0b47336 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -59,17 +59,16 @@ object StatefulNetworkWordCount { val wordDstream = words.map(x => (x, 1)) // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - - val trackStateFunc = (word: String, one: Option[Int], state: State[Int]) => { - val sum = one.getOrElse(0) + state.getOrElse(0) + // This will give a DStream made of state (which is the cumulative count of the words) + val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOption.getOrElse(0) val output = (word, sum) state.update(sum) Some(output) } val stateDstream = wordDstream.trackStateByKey( - StateSpec(trackStateFunc).initialState(initialRDD)) + StateSpec.function(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() From ae64786fd937002a2cc1f80518d54e970a6bbb21 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Nov 2015 16:34:38 -0800 Subject: [PATCH 26/26] Addressed type issue in StateSpec.function --- .../apache/spark/streaming/StateSpec.scala | 6 +-- .../dstream/PairDStreamFunctions.scala | 3 +- .../streaming/dstream/TrackStateDStream.scala | 12 ++++- .../streaming/TrackStateByKeySuite.scala | 53 ++++++++++++++++++- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 728f304730c4..c9fe35e74c1c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -150,15 +150,15 @@ object StateSpec { * 
@tparam StateType Class of the states data * @tparam EmittedType Class of the emitted data */ - def function[ValueType, StateType, EmittedType]( + def function[KeyType, ValueType, StateType, EmittedType]( trackingFunction: (Option[ValueType], State[StateType]) => EmittedType - ): StateSpec[Any, ValueType, StateType, EmittedType] = { + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { ClosureCleaner.clean(trackingFunction, checkSerializable = true) val wrappedFunction = (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { Some(trackingFunction(value, state)) } - new StateSpecImpl[Any, ValueType, StateType, EmittedType](wrappedFunction) + new StateSpecImpl(wrappedFunction) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 9bdffa0c2d7e..fb691eed27e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -380,7 +380,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) */ @Experimental def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( - spec: StateSpec[K, V, StateType, EmittedType]): TrackStateDStream[K, StateType, EmittedType] = { + spec: StateSpec[K, V, StateType, EmittedType] + ): TrackStateDStream[K, V, StateType, EmittedType] = { new TrackStateDStreamImpl[K, V, StateType, EmittedType]( self, spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 701ccab3562a..58d89c93bcbe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} * @tparam EmittedType Class of the emitted records */ @Experimental -sealed abstract class TrackStateDStream[KeyType, StateType, EmittedType: ClassTag]( +sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag]( ssc: StreamingContext) extends DStream[EmittedType](ssc) { /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. 
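   * One (key, state) pair is included for every key that currently has state.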
*/ @@ -51,7 +51,7 @@ private[streaming] class TrackStateDStreamImpl[ KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( dataStream: DStream[(KeyType, ValueType)], spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) - extends TrackStateDStream[KeyType, StateType, EmittedType](dataStream.context) { + extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) { private val internalStream = new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) @@ -78,6 +78,14 @@ private[streaming] class TrackStateDStreamImpl[ internalStream.flatMap { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } } + + def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass + + def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass + + def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass + + def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 15a85fc5cd21..e3072b444284 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -166,7 +167,8 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sum } - testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + testOperation[String, Int, Int]( + inputData, StateSpec.function(trackStateFunc), outputData, stateData) } test("trackStateByKey - basic operations with advanced API") { @@ -213,6 +215,55 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) } + test("trackStateByKey - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and emitted type as Double + val simpleFunc = (value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // emitted type as Double + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.emittedClass === classOf[Long]) + } + + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.trackStateByKey( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val 
simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.trackStateByKey( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + test("trackStateByKey - states as emitted records") { val inputData = Seq(