Skip to content

Commit dc54c71

Browse files
committed
Made changes based on PR comments.
1 parent 390b45d commit dc54c71

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -865,26 +865,25 @@ private[spark] object Utils extends Logging {
865865
}
866866

867867
/** Default filtering function for finding call sites using `getCallSite`. */
868-
private def defaultCallSiteFilterFunc(className: String): Boolean = {
868+
private def coreExclusionFunction(className: String): Boolean = {
869869
// A regular expression to match classes of the "core" Spark API that we want to skip when
870870
// finding the call site of a method.
871871
val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
872872
val SCALA_CLASS_REGEX = """^scala""".r
873-
val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
873+
val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
874874
val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
875-
// If the class neither belongs to Spark nor is a simple Scala class, then it is a
876-
// user-defined class
877-
!isSparkClass && !isScalaClass
875+
// If the class is a Spark internal class or a Scala class, then exclude.
876+
isSparkCoreClass || isScalaClass
878877
}
879878

880879
/**
881880
* When called inside a class in the spark package, returns the name of the user code class
882881
* (outside the spark package) that called into Spark, as well as which Spark method they called.
883882
* This is used, for example, to tell users where in their code each RDD got created.
884883
*
885-
* @param classFilterFunc Function that returns true if the given class belongs to user code
884+
* @param skipClass Function that is used to exclude non-user-code classes.
886885
*/
887-
def getCallSite(classFilterFunc: String => Boolean = defaultCallSiteFilterFunc): CallSite = {
886+
def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
888887
val trace = Thread.currentThread.getStackTrace()
889888
.filterNot { ste:StackTraceElement =>
890889
// When running under some profilers, the current stack trace might contain some bogus
@@ -905,7 +904,7 @@ private[spark] object Utils extends Logging {
905904

906905
for (el <- trace) {
907906
if (insideSpark) {
908-
if (!classFilterFunc(el.getClassName)) {
907+
if (skipClass(el.getClassName)) {
909908
lastSparkMethod = if (el.getMethodName == "<init>") {
910909
// Spark method is a constructor; get its class name
911910
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class StreamingContext private[streaming] (
447447
throw new SparkException("StreamingContext has already been stopped")
448448
}
449449
validate()
450-
sparkContext.setCallSite(DStream.getCallSite())
450+
sparkContext.setCallSite(DStream.getCreationSite())
451451
scheduler.start()
452452
state = Started
453453
}

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ abstract class DStream[T: ClassTag] (
108108
def context = ssc
109109

110110
/* Set the creation call site */
111-
private[streaming] val creationSite = DStream.getCallSite()
111+
private[streaming] val creationSite = DStream.getCreationSite()
112112

113113
/** Persist the RDDs of this DStream with the given storage level */
114114
def persist(level: StorageLevel): DStream[T] = {
@@ -805,25 +805,25 @@ abstract class DStream[T: ClassTag] (
805805
private[streaming] object DStream {
806806

807807
/** Get the creation site of a DStream from the stack trace of when the DStream is created. */
808-
def getCallSite(): CallSite = {
808+
def getCreationSite(): CallSite = {
809809
val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
810810
val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
811811
val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
812812
val SCALA_CLASS_REGEX = """^scala""".r
813813

814-
/** Filtering function that returns true for classes that belong to a streaming application */
815-
def streamingClassFilterFunc(className: String): Boolean = {
814+
/** Filtering function that excludes non-user classes for a streaming application */
815+
def streamingExclustionFunction(className: String): Boolean = {
816816
def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
817817
val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
818818
val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
819819
val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
820820
val isScalaClass = doesMatch(SCALA_CLASS_REGEX)
821821

822822
// If the class is a spark example class or a streaming test class then it is considered
823-
// as a streaming application class. Otherwise, consider any non-Spark and non-Scala class
824-
// as streaming application class.
825-
isSparkExampleClass || isSparkStreamingTestClass || !(isSparkClass || isScalaClass)
823+
// as a streaming application class and is not excluded. Otherwise, exclude any
824+
// Spark or Scala class, as the rest would be streaming application classes.
825+
(isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass
826826
}
827-
org.apache.spark.util.Utils.getCallSite(streamingClassFilterFunc)
827+
org.apache.spark.util.Utils.getCallSite(streamingExclustionFunction)
828828
}
829829
}

0 commit comments

Comments (0)