@@ -23,14 +23,15 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 import scala.deprecated
 import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
+import scala.util.matching.Regex

 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.StreamingContext._
 import org.apache.spark.streaming.scheduler.Job
-import org.apache.spark.util.{CallSite, Utils, MetadataCleaner}
+import org.apache.spark.util.{CallSite, MetadataCleaner}

 /**
  * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
@@ -106,20 +107,8 @@ abstract class DStream[T: ClassTag] (
   /** Return the StreamingContext associated with this DStream */
   def context = ssc

-  /* Find the creation callSite */
-  val creationSite = Utils.getCallSite(org.apache.spark.streaming.util.Utils.streamingRegexFunc)
-
-  /* Store the RDD creation callSite in threadlocal */
-  private def setRDDCreationCallSite(callSite: CallSite = creationSite) = {
-    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_SHORT, callSite.shortForm)
-    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_LONG, callSite.longForm)
-  }
-
-  /* Return the current callSite */
-  private def getRDDCreationCallSite(): CallSite = {
-    CallSite(ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_SHORT),
-      ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_LONG))
-  }
+  /* Record the creation call site */
+  private[streaming] val creationSite = DStream.getCallSite()

   /** Persist the RDDs of this DStream with the given storage level */
   def persist(level: StorageLevel): DStream[T] = {
@@ -264,6 +253,16 @@ abstract class DStream[T: ClassTag] (
     dependencies.foreach(_.setGraph(graph))
   }

+  /* Set the custom RDD creation site as this thread's local property. */
+  private def setRDDCreationSite(creationSite: CallSite): Unit = {
+    ssc.sparkContext.setCallSite(creationSite)
+  }
+
+  /* Get the custom RDD creation site set as this thread's local property. */
+  private def getRDDCreationSite(): CallSite = {
+    ssc.sparkContext.getCallSite()
+  }
+
   private[streaming] def remember(duration: Duration) {
     if (duration != null && duration > rememberDuration) {
       rememberDuration = duration
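
The two helpers above are thin wrappers around SparkContext's per-thread call-site state. As a minimal standalone sketch of the same idea, assuming a simplified CallSite(shortForm, longForm) and a hypothetical CallSiteStore (this is not Spark's internal implementation):

// Sketch: per-thread call-site storage.
case class CallSite(shortForm: String, longForm: String)

object CallSiteStore {
  private val current = new ThreadLocal[CallSite] {
    override def initialValue(): CallSite = CallSite("<unset>", "<unset>")
  }
  def get(): CallSite = current.get()            // what getRDDCreationSite reads
  def set(cs: CallSite): Unit = current.set(cs)  // what setRDDCreationSite writes
}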
@@ -287,47 +286,41 @@ abstract class DStream[T: ClassTag] (
   }

   /**
-   * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
-   * method that should not be called directly.
+   * Get the RDD corresponding to the given time; either retrieve it from cache
+   * or compute-and-cache it.
    */
   private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
-    // If this DStream was not initialized (i.e., zeroTime not set), then do it
-    // If RDD was already generated, then retrieve it from HashMap
-    generatedRDDs.get(time) match {
-
-      // If an RDD was already generated and is being reused, then
-      // probably all RDDs in this DStream will be reused and hence should be cached
-      case Some(oldRDD) => Some(oldRDD)
-
-      // if RDD was not generated, and if the time is valid
-      // (based on sliding time of this DStream), then generate the RDD
-      case None => {
-        if (isTimeValid(time)) {
-          val prevCallSite = getRDDCreationCallSite
-          setRDDCreationCallSite()
-          val rddOption = compute(time) match {
-            case Some(newRDD) =>
-              if (storageLevel != StorageLevel.NONE) {
-                newRDD.persist(storageLevel)
-                logInfo("Persisting RDD " + newRDD.id + " for time " +
-                  time + " to " + storageLevel + " at time " + time)
-              }
-              if (checkpointDuration != null &&
-                (time - zeroTime).isMultipleOf(checkpointDuration)) {
-                newRDD.checkpoint()
-                logInfo("Marking RDD " + newRDD.id + " for time " + time +
-                  " for checkpointing at time " + time)
-              }
-              generatedRDDs.put(time, newRDD)
-              Some(newRDD)
-            case None =>
-              return None
+    // If the RDD was already generated, retrieve it from the HashMap,
+    // or else compute the RDD.
+    generatedRDDs.get(time).orElse {
+      // Compute the RDD if the time is valid (e.g. the correct time in a sliding window)
+      // for RDD generation, else generate nothing.
+      if (isTimeValid(time)) {
+        // Set the thread-local property for call sites to this DStream's creation site
+        // so that RDDs generated by compute get that as their creation site.
+        // Note that `getOrCompute` may get called from another DStream that may have
+        // set its own call site. So we store its call site in a temporary variable, set
+        // this DStream's creation site, generate the RDDs, and then restore the previous call site.
+        val prevCallSite = ssc.sparkContext.getCallSite()
+        ssc.sparkContext.setCallSite(creationSite)
+        val rddOption = compute(time)
+        ssc.sparkContext.setCallSite(prevCallSite)
+
+        rddOption.foreach { case newRDD =>
+          // Register the generated RDD for caching and checkpointing
+          if (storageLevel != StorageLevel.NONE) {
+            newRDD.persist(storageLevel)
+            logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
           }
-          setRDDCreationCallSite(prevCallSite)
-          return rddOption
-        } else {
-          return None
+          if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
+            newRDD.checkpoint()
+            logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
+          }
+          generatedRDDs.put(time, newRDD)
         }
+        rddOption
+      } else {
+        None
       }
     }
   }
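
The save/set/restore sequence in getOrCompute exists because compute(time) may recursively call getOrCompute on parent DStreams, each of which sets its own creation site. Using the CallSiteStore sketch above, the same discipline could be factored into a hypothetical wrapper (the patch inlines it instead, without a try/finally):

// Hypothetical helper, not part of the patch: run `body` with `cs` as the
// current call site, restoring whatever the caller had set afterwards.
def withCallSite[T](cs: CallSite)(body: => T): T = {
  val prev = CallSiteStore.get()   // may be another DStream's creation site
  CallSiteStore.set(cs)
  try body
  finally CallSiteStore.set(prev)  // restore even if body throws
}

The try/finally keeps the thread-local consistent when compute throws; the inlined version above would leave this DStream's site set in that case.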
@@ -818,3 +811,29 @@ abstract class DStream[T: ClassTag] (
     this
   }
 }
+
+private[streaming] object DStream {
+
+  /** Get the creation site of a DStream from the stack trace of when the DStream is created. */
+  def getCallSite(): CallSite = {
+    val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
+    val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
+    val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
+    val SCALA_CLASS_REGEX = """^scala""".r
+
+    /** Filtering function that returns true for classes that belong to a streaming application */
+    def streamingClassFilterFunc(className: String): Boolean = {
+      def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
+      val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
+      val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
+      val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
+      val isScalaClass = doesMatch(SCALA_CLASS_REGEX)
+
+      // A Spark example class or a streaming test class is considered a streaming
+      // application class. Otherwise, any non-Spark and non-Scala class is considered
+      // a streaming application class.
+      isSparkExampleClass || isSparkStreamingTestClass || !(isSparkClass || isScalaClass)
+    }
+    org.apache.spark.util.Utils.getCallSite(streamingClassFilterFunc)
+  }
+}
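
To make the filter concrete, here is what streamingClassFilterFunc would return for a few illustrative class names, given the four regexes above (the user and test class names are hypothetical):

// Rejected: Spark internals and the Scala library.
streamingClassFilterFunc("org.apache.spark.streaming.dstream.DStream")  // false
streamingClassFilterFunc("scala.collection.mutable.HashMap")            // false
// Accepted: example, streaming-test, and user application classes.
streamingClassFilterFunc("org.apache.spark.examples.streaming.NetworkWordCount")  // true
streamingClassFilterFunc("org.apache.spark.streaming.test.MyAppSuite")            // true
streamingClassFilterFunc("com.mycompany.MyStreamingApp")                          // true

Utils.getCallSite uses this predicate to pick the application frame out of the stack trace at DStream construction time; getOrCompute then installs that site on every RDD the DStream generates.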