@@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
/**
 * CallSite represents a place in user code. It can have a short and a long form.
 *
 * @param shortForm the short description of the call site
 * @param longForm  the long description of the call site
 */
private[spark] case class CallSite(shortForm: String, longForm: String)

/**
 * String constants naming the two forms of a [[CallSite]].
 * NOTE(review): presumably used as lookup keys by callers -- confirm against usages.
 */
private[spark] object CallSite {
  // No surrounding whitespace: a stray leading space would become part of the key itself.
  val SHORT_FORM: String = "callSite.short"
  val LONG_FORM: String = "callSite.long"
}
56+
5257/**
5358 * Various utility methods used by Spark.
5459 */
@@ -859,18 +864,26 @@ private[spark] object Utils extends Logging {
859864 }
860865 }
861866
862- /**
863- * A regular expression to match classes of the "core" Spark API that we want to skip when
864- * finding the call site of a method.
865- */
866- private val SPARK_CLASS_REGEX = """ ^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""" .r
867+ /** Default filtering function for finding call sites using `getCallSite`. */
868+ private def coreExclusionFunction (className : String ): Boolean = {
869+ // A regular expression to match classes of the "core" Spark API that we want to skip when
870+ // finding the call site of a method.
871+ val SPARK_CORE_CLASS_REGEX = """ ^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""" .r
872+ val SCALA_CLASS_REGEX = """ ^scala""" .r
873+ val isSparkCoreClass = SPARK_CORE_CLASS_REGEX .findFirstIn(className).isDefined
874+ val isScalaClass = SCALA_CLASS_REGEX .findFirstIn(className).isDefined
875+ // If the class is a Spark internal class or a Scala class, then exclude.
876+ isSparkCoreClass || isScalaClass
877+ }
867878
868879 /**
869880 * When called inside a class in the spark package, returns the name of the user code class
870881 * (outside the spark package) that called into Spark, as well as which Spark method they called.
871882 * This is used, for example, to tell users where in their code each RDD got created.
883+ *
884+ * @param skipClass Function that is used to exclude non-user-code classes.
872885 */
873- def getCallSite : CallSite = {
886+ def getCallSite ( skipClass : String => Boolean = coreExclusionFunction) : CallSite = {
874887 val trace = Thread .currentThread.getStackTrace()
875888 .filterNot { ste: StackTraceElement =>
876889 // When running under some profilers, the current stack trace might contain some bogus
@@ -891,7 +904,7 @@ private[spark] object Utils extends Logging {
891904
892905 for (el <- trace) {
893906 if (insideSpark) {
894- if (SPARK_CLASS_REGEX .findFirstIn (el.getClassName).isDefined ) {
907+ if (skipClass (el.getClassName)) {
895908 lastSparkMethod = if (el.getMethodName == " <init>" ) {
896909 // Spark method is a constructor; get its class name
897910 el.getClassName.substring(el.getClassName.lastIndexOf('.' ) + 1 )
0 commit comments