
Commit 7baa427

Refactored getCallSite and setCallSite to make them simpler. Also added a unit test for the DStream creation site.

1 parent b9ed945

File tree: 7 files changed, +161 −127 lines


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 22 additions & 10 deletions

@@ -1024,28 +1024,40 @@ class SparkContext(config: SparkConf) extends Logging {
   }
 
   /**
-   * Support function for API backtraces.
+   * Set the thread-local property for overriding the call sites
+   * of actions and RDDs.
    */
-  def setCallSite(site: String) {
-    setLocalProperty("externalCallSite", site)
+  def setCallSite(shortCallSite: String) {
+    setLocalProperty(CallSite.SHORT_FORM, shortCallSite)
   }
 
   /**
-   * Support function for API backtraces.
+   * Set the thread-local property for overriding the call sites
+   * of actions and RDDs.
+   */
+  private[spark] def setCallSite(callSite: CallSite) {
+    setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm)
+    setLocalProperty(CallSite.LONG_FORM, callSite.longForm)
+  }
+
+  /**
+   * Clear the thread-local property for overriding the call sites
+   * of actions and RDDs.
    */
   def clearCallSite() {
-    setLocalProperty("externalCallSite", null)
+    setLocalProperty(CallSite.SHORT_FORM, null)
+    setLocalProperty(CallSite.LONG_FORM, null)
   }
 
   /**
    * Capture the current user callsite and return a formatted version for printing. If the user
-   * has overridden the call site, this will return the user's version.
+   * has overridden the call site using `setCallSite()`, this will return the user's version.
    */
   private[spark] def getCallSite(): CallSite = {
-    Option(getLocalProperty("externalCallSite")) match {
-      case Some(callSite) => CallSite(callSite, longForm = "")
-      case None => Utils.getCallSite()
-    }
+    Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite =>
+      val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("")
+      CallSite(shortCallSite, longCallSite)
+    }.getOrElse(Utils.getCallSite())
   }
 
   /**
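For reference, the short form is the one-line summary that surfaces in places like the Spark UI, while the long form carries the fuller stack trace. A minimal sketch of how a driver program might exercise the refactored API (hypothetical snippet, assuming a live SparkContext named sc and made-up file names; note that only the String overload and clearCallSite() are public, while the CallSite overload is private[spark]):

    // Hypothetical driver code, not part of this commit.
    sc.setCallSite("loading ratings: ingest step 1")
    // RDDs created on this thread now report the overridden site,
    // because RDD.creationSite reads sc.getCallSite().
    val ratings = sc.textFile("ratings.csv")

    // Clear the override; subsequent RDDs fall back to Utils.getCallSite(),
    // which walks the stack for the first user-code frame.
    sc.clearCallSite()
    val users = sc.textFile("users.csv")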

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 8 deletions

@@ -1220,14 +1220,7 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
-  @transient private[spark] val creationSite = {
-    val short: String = sc.getLocalProperty(Utils.CALL_SITE_SHORT)
-    if (short != null) {
-      CallSite(short, sc.getLocalProperty(Utils.CALL_SITE_LONG))
-    } else {
-      Utils.getCallSite()
-    }
-  }
+  @transient private[spark] val creationSite = sc.getCallSite()
 
   private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("")
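The net effect is that the fallback logic now lives in one place: with no override set, sc.getCallSite() delegates to Utils.getCallSite(), which attributes the RDD to the first user-code frame on the stack. Roughly, for a driver program like the following (hypothetical program, file name, and line numbers), the captured short form would read something like "parallelize at MyApp.scala:5":

    import org.apache.spark.{SparkConf, SparkContext}

    // MyApp.scala (hypothetical user program)
    object MyApp {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("MyApp").setMaster("local"))
        val rdd = sc.parallelize(1 to 100)  // line 5: creationSite is captured here
        sc.stop()
      }
    }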

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 20 additions & 15 deletions

@@ -47,15 +47,17 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
 /** CallSite represents a place in user code. It can have a short and a long form. */
 private[spark] case class CallSite(shortForm: String, longForm: String)
 
+private[spark] object CallSite {
+  val SHORT_FORM = "callSite.short"
+  val LONG_FORM = "callSite.long"
+}
+
 /**
  * Various utility methods used by Spark.
  */
 private[spark] object Utils extends Logging {
   val random = new Random()
 
-  private[spark] val CALL_SITE_SHORT: String = "callSite.short"
-  private[spark] val CALL_SITE_LONG: String = "callSite.long"
-
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()

@@ -854,24 +856,27 @@ private[spark] object Utils extends Logging {
     }
   }
 
-  /**
-   * A regular expression to match classes of the "core" Spark API that we want to skip when
-   * finding the call site of a method.
-   */
-  private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
-  val SCALA_CLASS_REGEX = """^scala""".r
-
-  private def defaultRegexFunc(className: String): Boolean = {
-    SPARK_CLASS_REGEX.findFirstIn(className).isDefined ||
-    SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+  /** Default filtering function for finding call sites using `getCallSite`. */
+  private def defaultCallSiteFilterFunc(className: String): Boolean = {
+    // A regular expression to match classes of the "core" Spark API that we want to skip when
+    // finding the call site of a method.
+    val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+    val SCALA_CLASS_REGEX = """^scala""".r
+    val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
+    val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+    // If the class neither belongs to Spark nor is a simple Scala class, then it is a
+    // user-defined class
+    !isSparkClass && !isScalaClass
   }
 
   /**
    * When called inside a class in the spark package, returns the name of the user code class
    * (outside the spark package) that called into Spark, as well as which Spark method they called.
    * This is used, for example, to tell users where in their code each RDD got created.
+   *
+   * @param classFilterFunc Function that returns true if the given class belongs to user code
    */
-  def getCallSite(regexFunc: String => Boolean = defaultRegexFunc(_)): CallSite = {
+  def getCallSite(classFilterFunc: String => Boolean = defaultCallSiteFilterFunc): CallSite = {
     val trace = Thread.currentThread.getStackTrace()
       .filterNot { ste: StackTraceElement =>
         // When running under some profilers, the current stack trace might contain some bogus

@@ -892,7 +897,7 @@ private[spark] object Utils extends Logging {
 
     for (el <- trace) {
       if (insideSpark) {
-        if (regexFunc(el.getClassName)) {
+        if (!classFilterFunc(el.getClassName)) {
           lastSparkMethod = if (el.getMethodName == "<init>") {
             // Spark method is a constructor; get its class name
             el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
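Note the polarity flip hidden in this rename: the old regexFunc returned true for framework classes to skip, whereas classFilterFunc returns true for classes that belong to user code, which is why the check in the stack walk is now negated. A sketch of supplying a custom filter (only callable from inside Spark's own packages, since Utils is private[spark]; the package and file names are made up):

    import org.apache.spark.util.Utils

    // Treat only classes under com.example as user code; everything else,
    // including third-party libraries, is skipped as framework code.
    def exampleOnlyFilter(className: String): Boolean =
      className.startsWith("com.example.")

    val site = Utils.getCallSite(exampleOnlyFilter)
    // site.shortForm would read something like "someMethod at MyJob.scala:42"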

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 1 addition & 3 deletions

@@ -441,9 +441,7 @@ class StreamingContext private[streaming] (
       throw new SparkException("StreamingContext has already been stopped")
     }
     validate()
-    sc.setCallSite(
-      Utils.getCallSite(org.apache.spark.streaming.util.Utils.streamingRegexFunc).shortForm
-    )
+    sparkContext.setCallSite(DStream.getCallSite())
     scheduler.start()
     state = Started
   }

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 71 additions & 52 deletions

@@ -23,14 +23,15 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 import scala.deprecated
 import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
+import scala.util.matching.Regex
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.StreamingContext._
 import org.apache.spark.streaming.scheduler.Job
-import org.apache.spark.util.{CallSite, Utils, MetadataCleaner}
+import org.apache.spark.util.{CallSite, MetadataCleaner}
 
 /**
  * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous

@@ -106,20 +107,8 @@ abstract class DStream[T: ClassTag] (
   /** Return the StreamingContext associated with this DStream */
   def context = ssc
 
-  /* Find the creation callSite */
-  val creationSite = Utils.getCallSite(org.apache.spark.streaming.util.Utils.streamingRegexFunc)
-
-  /* Store the RDD creation callSite in threadlocal */
-  private def setRDDCreationCallSite(callSite: CallSite = creationSite) = {
-    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_SHORT, callSite.shortForm)
-    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_LONG, callSite.longForm)
-  }
-
-  /* Return the current callSite */
-  private def getRDDCreationCallSite(): CallSite = {
-    CallSite(ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_SHORT),
-      ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_LONG))
-  }
+  /* Set the creation call site */
+  private[streaming] val creationSite = DStream.getCallSite()
 
   /** Persist the RDDs of this DStream with the given storage level */
   def persist(level: StorageLevel): DStream[T] = {
@@ -264,6 +253,16 @@ abstract class DStream[T: ClassTag] (
     dependencies.foreach(_.setGraph(graph))
   }
 
+  /* Set the custom RDD creation site as this thread's local property. */
+  private def setRDDCreationSite(creationSite: CallSite): Unit = {
+    ssc.sparkContext.setCallSite(creationSite)
+  }
+
+  /* Get the custom RDD creation site set as this thread's local property. */
+  private def getRDDCreationSite(): CallSite = {
+    ssc.sparkContext.getCallSite()
+  }
+
   private[streaming] def remember(duration: Duration) {
     if (duration != null && duration > rememberDuration) {
       rememberDuration = duration
@@ -287,47 +286,41 @@ abstract class DStream[T: ClassTag] (
   }
 
   /**
-   * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
-   * method that should not be called directly.
+   * Get the RDD corresponding to the given time; either retrieve it from cache
+   * or compute-and-cache it.
    */
   private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
-    // If this DStream was not initialized (i.e., zeroTime not set), then do it
-    // If RDD was already generated, then retrieve it from HashMap
-    generatedRDDs.get(time) match {
-
-      // If an RDD was already generated and is being reused, then
-      // probably all RDDs in this DStream will be reused and hence should be cached
-      case Some(oldRDD) => Some(oldRDD)
-
-      // if RDD was not generated, and if the time is valid
-      // (based on sliding time of this DStream), then generate the RDD
-      case None => {
-        if (isTimeValid(time)) {
-          val prevCallSite = getRDDCreationCallSite
-          setRDDCreationCallSite()
-          val rddOption = compute(time) match {
-            case Some(newRDD) =>
-              if (storageLevel != StorageLevel.NONE) {
-                newRDD.persist(storageLevel)
-                logInfo("Persisting RDD " + newRDD.id + " for time " +
-                  time + " to " + storageLevel + " at time " + time)
-              }
-              if (checkpointDuration != null &&
-                (time - zeroTime).isMultipleOf(checkpointDuration)) {
-                newRDD.checkpoint()
-                logInfo("Marking RDD " + newRDD.id + " for time " + time +
-                  " for checkpointing at time " + time)
-              }
-              generatedRDDs.put(time, newRDD)
-              Some(newRDD)
-            case None =>
-              return None
+    // If RDD was already generated, then retrieve it from HashMap,
+    // or else compute the RDD
+    generatedRDDs.get(time).orElse {
+      // Compute the RDD if time is valid (e.g. correct time in a sliding window)
+      // of RDD generation, else generate nothing.
+      if (isTimeValid(time)) {
+        // Set the thread-local property for call sites to this DStream's creation site
+        // such that RDDs generated by compute gets that as their creation site.
+        // Note that this `getOrCompute` may get called from another DStream which may have
+        // set its own call site. So we store its call site in a temporary variable,
+        // set this DStream's creation site, generate RDDs and then restore the previous call site.
+        val prevCallSite = ssc.sparkContext.getCallSite()
+        ssc.sparkContext.setCallSite(creationSite)
+        val rddOption = compute(time)
+        ssc.sparkContext.setCallSite(prevCallSite)
+
+        rddOption.foreach { case newRDD =>
+          // Register the generated RDD for caching and checkpointing
+          if (storageLevel != StorageLevel.NONE) {
+            newRDD.persist(storageLevel)
+            logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
           }
-          setRDDCreationCallSite(prevCallSite)
-          return rddOption
-        } else {
-          return None
+          if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
+            newRDD.checkpoint()
+            logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
+          }
+          generatedRDDs.put(time, newRDD)
         }
+        rddOption
+      } else {
+        None
       }
     }
   }
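The save/set/restore sequence above is the standard discipline for scoping a thread-local override: getOrCompute can be re-entered from a downstream DStream that has already installed its own call site, so the previous value must be put back once compute returns. Condensed to just that pattern (a sketch; the commit itself restores unconditionally, and the try/finally shown here is a variation that would additionally protect against exceptions thrown by compute):

    val prevCallSite = ssc.sparkContext.getCallSite()  // may belong to another DStream
    ssc.sparkContext.setCallSite(creationSite)         // attribute compute's RDDs to this DStream
    try {
      compute(time)
    } finally {
      ssc.sparkContext.setCallSite(prevCallSite)       // always restore the caller's site
    }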
@@ -818,3 +811,29 @@ abstract class DStream[T: ClassTag] (
     this
   }
 }
+
+private[streaming] object DStream {
+
+  /** Get the creation site of a DStream from the stack trace of when the DStream is created. */
+  def getCallSite(): CallSite = {
+    val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
+    val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
+    val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
+    val SCALA_CLASS_REGEX = """^scala""".r
+
+    /** Filtering function that returns true for classes that belong to a streaming application */
+    def streamingClassFilterFunc(className: String): Boolean = {
+      def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
+      val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
+      val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
+      val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
+      val isScalaClass = doesMatch(SCALA_CLASS_REGEX)
+
+      // If the class is a spark example class or a streaming test class then it is considered
+      // as a streaming application class. Otherwise, consider any non-Spark and non-Scala class
+      // as a streaming application class.
+      isSparkExampleClass || isSparkStreamingTestClass || !(isSparkClass || isScalaClass)
+    }
+    org.apache.spark.util.Utils.getCallSite(streamingClassFilterFunc)
+  }
+}
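To make the classification concrete, here is roughly how streamingClassFilterFunc would sort a few class names (illustrative inputs only, and the function is nested inside DStream.getCallSite, so it is shown here as if extracted; true means "streaming application code", i.e. a frame eligible to be the call site):

    streamingClassFilterFunc("org.apache.spark.examples.streaming.NetworkWordCount")  // true: examples count as app code
    streamingClassFilterFunc("org.apache.spark.streaming.testsuite.SomeSuite")        // true: test classes opt in
    streamingClassFilterFunc("com.mycompany.MyStreamingJob")                          // true: neither Spark nor Scala
    streamingClassFilterFunc("org.apache.spark.streaming.dstream.DStream")            // false: Spark internals
    streamingClassFilterFunc("scala.collection.immutable.List")                       // false: Scala library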

streaming/src/main/scala/org/apache/spark/streaming/util/Utils.scala

Lines changed: 0 additions & 34 deletions
This file was deleted.
