Skip to content

Commit f74aad1

Browse files
committed
Avoid Option while generating call site & add unit tests
1 parent d2b4980 commit f74aad1

File tree

4 files changed

+47
-12
lines changed

4 files changed

+47
-12
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,8 @@ class SparkContext(
877877
* has overridden the call site, this will return the user's version.
878878
*/
879879
private[spark] def getCallSite(): String = {
880-
Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatCallSiteInfo())
880+
val defaultCallSite = Utils.getCallSiteInfo
881+
Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
881882
}
882883

883884
/**

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ abstract class RDD[T: ClassTag](
10411041

10421042
/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
10431043
@transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
1044-
private[spark] def getCreationSite = Utils.formatCallSiteInfo(creationSiteInfo)
1044+
private[spark] def getCreationSite: String = creationSiteInfo.toString
10451045

10461046
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
10471047

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -676,16 +676,22 @@ private[spark] object Utils extends Logging {
676676
private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
677677

678678
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
679-
val firstUserLine: Int, val firstUserClass: String)
679+
val firstUserLine: Int, val firstUserClass: String) {
680+
681+
/** Returns a printable version of the call site info suitable for logs. */
682+
override def toString = {
683+
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
684+
}
685+
}
680686

681687
/**
682688
* When called inside a class in the spark package, returns the name of the user code class
683689
* (outside the spark package) that called into Spark, as well as which Spark method they called.
684690
* This is used, for example, to tell users where in their code each RDD got created.
685691
*/
686692
def getCallSiteInfo: CallSiteInfo = {
687-
val trace = Thread.currentThread.getStackTrace().filter( el =>
688-
((!el.getMethodName.contains("getStackTrace")) && (el.getClassName != "scala.Option")))
693+
val trace = Thread.currentThread.getStackTrace()
694+
.filterNot(_.getMethodName.contains("getStackTrace"))
689695

690696
// Keep crawling up the stack trace until we find the first function not inside of the spark
691697
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
@@ -718,12 +724,6 @@ private[spark] object Utils extends Logging {
718724
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
719725
}
720726

721-
/** Returns a printable version of the call site info suitable for logs. */
722-
def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) = {
723-
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
724-
callSiteInfo.firstUserLine)
725-
}
726-
727727
/** Return a string containing part of a file from byte 'start' to 'end'. */
728728
def offsetBytes(path: String, start: Long, end: Long): String = {
729729
val file = new File(path)

core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark
1919

20-
import org.scalatest.FunSuite
20+
import org.scalatest.{Assertions, FunSuite}
2121

2222
class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
2323
test("getPersistentRDDs only returns RDDs that are marked as cached") {
@@ -56,4 +56,38 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
5656
rdd.collect()
5757
assert(sc.getRDDStorageInfo.size === 1)
5858
}
59+
60+
test("call sites report correct locations") {
61+
sc = new SparkContext("local", "test")
62+
testPackage.runCallSiteTest(sc)
63+
}
64+
}
65+
66+
/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
67+
package object testPackage extends Assertions {
68+
private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r
69+
70+
def runCallSiteTest(sc: SparkContext) {
71+
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
72+
val rddCreationSite = rdd.getCreationSite
73+
val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd"
74+
75+
val rddCreationLine = rddCreationSite match {
76+
case CALL_SITE_REGEX(func, file, line) => {
77+
assert(func === "makeRDD")
78+
assert(file === "SparkContextInfoSuite.scala")
79+
line.toInt
80+
}
81+
case _ => fail("Did not match expected call site format")
82+
}
83+
84+
curCallSite match {
85+
case CALL_SITE_REGEX(func, file, line) => {
86+
assert(func === "getCallSite") // this is correct because we called it from outside of Spark
87+
assert(file === "SparkContextInfoSuite.scala")
88+
assert(line.toInt === rddCreationLine.toInt + 2)
89+
}
90+
case _ => fail("Did not match expected call site format")
91+
}
92+
}
5993
}

0 commit comments

Comments
 (0)