23 commits
- d466d75  Changes for spark streaming UI (mubarak, Jul 18, 2014)
- 9d38d3c  [SPARK-1853] Show Streaming application code context (file, line numb… (mubarak, Jul 18, 2014)
- 1500deb  Changes in Spark Streaming UI (mubarak, Jul 18, 2014)
- 70f494f  Changes for SPARK-1853 (mubarak, Aug 10, 2014)
- 5f3105a  Merge remote-tracking branch 'upstream/master' (mubarak, Aug 11, 2014)
- 1d90cc3  Changes for SPARK-1853 (mubarak, Aug 11, 2014)
- 2a09ad6  Changes in Utils.scala for SPARK-1853 (mubarak, Aug 11, 2014)
- ccde038  Removing Utils import from MappedDStream (mubarak, Aug 11, 2014)
- a207eb7  Fixing code review comments (mubarak, Aug 18, 2014)
- 5051c58  Getting return value of compute() into variable and call setCallSite(… (mubarak, Aug 18, 2014)
- f51fd9f  Fixing scalastyle, Regex for Utils.getCallSite, and changing method n… (mubarak, Aug 19, 2014)
- c26d933  Merge remote-tracking branch 'upstream/master' (mubarak, Aug 20, 2014)
- 33a7295  Fixing review comments: Merging both setCallSite methods (mubarak, Aug 20, 2014)
- 491a1eb  Removing streaming visibility from getRDDCreationCallSite in DStream (mubarak, Aug 20, 2014)
- 196121b  Merge remote-tracking branch 'upstream/master' (mubarak, Aug 21, 2014)
- 8c5d443  Merge remote-tracking branch 'upstream/master' (mubarak, Sep 5, 2014)
- ceb43da  Changing default regex function name (mubarak, Sep 5, 2014)
- c461cf4  Merge remote-tracking branch 'upstream/master' (mubarak, Sep 6, 2014)
- b9ed945  Adding streaming utils (mubarak, Sep 6, 2014)
- 7baa427  Refactored getCallSite and setCallSite to make it simpler. Also added… (tdas, Sep 19, 2014)
- 904cd92  Merge remote-tracking branch 'apache-github/master' into streaming-ca… (tdas, Sep 19, 2014)
- 390b45d  Fixed minor bugs. (tdas, Sep 19, 2014)
- dc54c71  Made changes based on PR comments. (tdas, Sep 23, 2014)
32 changes: 22 additions & 10 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1030,28 +1030,40 @@ class SparkContext(config: SparkConf) extends Logging {
}

/**
* Support function for API backtraces.
* Set the thread-local property for overriding the call sites
* of actions and RDDs.
*/
def setCallSite(site: String) {
setLocalProperty("externalCallSite", site)
def setCallSite(shortCallSite: String) {
setLocalProperty(CallSite.SHORT_FORM, shortCallSite)
}

/**
* Support function for API backtraces.
* Set the thread-local property for overriding the call sites
* of actions and RDDs.
*/
private[spark] def setCallSite(callSite: CallSite) {
setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm)
setLocalProperty(CallSite.LONG_FORM, callSite.longForm)
}

/**
* Clear the thread-local property for overriding the call sites
* of actions and RDDs.
*/
def clearCallSite() {
setLocalProperty("externalCallSite", null)
setLocalProperty(CallSite.SHORT_FORM, null)
setLocalProperty(CallSite.LONG_FORM, null)
}

/**
* Capture the current user callsite and return a formatted version for printing. If the user
* has overridden the call site, this will return the user's version.
* has overridden the call site using `setCallSite()`, this will return the user's version.
*/
private[spark] def getCallSite(): CallSite = {
Option(getLocalProperty("externalCallSite")) match {
case Some(callSite) => CallSite(callSite, longForm = "")
case None => Utils.getCallSite
}
Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite =>
val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("")
CallSite(shortCallSite, longCallSite)
}.getOrElse(Utils.getCallSite())
}

/**
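For orientation, a minimal sketch (not part of this diff) of how the public short-form override is intended to be used from application code; the context name `sc`, the HDFS path, and the label text are illustrative assumptions:

```scala
// Assumes an existing SparkContext `sc`; the label and path are made up for illustration.
sc.setCallSite("loading event log")                 // stored under CallSite.SHORT_FORM
val events = sc.textFile("hdfs:///tmp/events.txt")  // the RDD picks up the overridden call site
events.count()                                      // the job reports "loading event log" as its call site
sc.clearCallSite()                                  // resets both SHORT_FORM and LONG_FORM properties
```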
7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.rdd

import java.util.Random
import java.util.{Properties, Random}

import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
@@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

@@ -1224,7 +1224,8 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSite = Utils.getCallSite
@transient private[spark] val creationSite = sc.getCallSite()

private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("")

private[spark] def elementClassTag: ClassTag[T] = classTag[T]
27 changes: 20 additions & 7 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)

private[spark] object CallSite {
val SHORT_FORM = "callSite.short"
val LONG_FORM = "callSite.long"
}

/**
* Various utility methods used by Spark.
*/
@@ -859,18 +864,26 @@
}
}

/**
* A regular expression to match classes of the "core" Spark API that we want to skip when
* finding the call site of a method.
*/
private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
/** Default filtering function for finding call sites using `getCallSite`. */
private def coreExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the "core" Spark API that we want to skip when
// finding the call site of a method.
val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
val SCALA_CLASS_REGEX = """^scala""".r
val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
// If the class is a Spark internal class or a Scala class, then exclude.
isSparkCoreClass || isScalaClass
}

/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*
* @param skipClass Function that is used to exclude non-user-code classes.
*/
def getCallSite: CallSite = {
def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
val trace = Thread.currentThread.getStackTrace()
.filterNot { ste:StackTraceElement =>
// When running under some profilers, the current stack trace might contain some bogus
@@ -891,7 +904,7 @@

for (el <- trace) {
if (insideSpark) {
if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) {
if (skipClass(el.getClassName)) {
lastSparkMethod = if (el.getMethodName == "<init>") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
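To illustrate the new `skipClass` hook, a hedged sketch of a Spark-internal caller supplying its own exclusion function (`Utils` and `CallSite` are `private[spark]`, so this is only callable from Spark code; the wrapper package `com.example.mylib` is an invented name):

```scala
import org.apache.spark.util.{CallSite, Utils}

// Hypothetical: also treat a wrapper library's package as framework code, so the reported
// call site points at the code that calls the wrapper rather than the wrapper itself.
val site: CallSite = Utils.getCallSite { className =>
  className.startsWith("org.apache.spark") ||
  className.startsWith("scala") ||
  className.startsWith("com.example.mylib")   // assumed wrapper package, not a real one
}
println(site.shortForm)   // e.g. "runQuery at MyApp.scala:42" (illustrative output)
```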
streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -35,10 +35,9 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver}
import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
import org.apache.spark.util.MetadataCleaner

/**
* Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -448,6 +447,7 @@ class StreamingContext private[streaming] (
throw new SparkException("StreamingContext has already been stopped")
}
validate()
sparkContext.setCallSite(DStream.getCreationSite())
scheduler.start()
state = Started
}
streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -23,14 +23,15 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import scala.deprecated
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
import scala.util.matching.Regex

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.scheduler.Job
import org.apache.spark.util.MetadataCleaner
import org.apache.spark.util.{CallSite, MetadataCleaner}

/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
@@ -106,6 +107,9 @@ abstract class DStream[T: ClassTag] (
/** Return the StreamingContext associated with this DStream */
def context = ssc

/* The call site of the user code that created this DStream, captured at construction time */
private[streaming] val creationSite = DStream.getCreationSite()

/** Persist the RDDs of this DStream with the given storage level */
def persist(level: StorageLevel): DStream[T] = {
if (this.isInitialized) {
@@ -272,43 +276,41 @@
}

/**
* Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
* method that should not be called directly.
* Get the RDD corresponding to the given time; either retrieve it from cache
* or compute-and-cache it.
*/
private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {

// If an RDD was already generated and is being reused, then
// probably all RDDs in this DStream will be reused and hence should be cached
case Some(oldRDD) => Some(oldRDD)

// if RDD was not generated, and if the time is valid
// (based on sliding time of this DStream), then generate the RDD
case None => {
if (isTimeValid(time)) {
compute(time) match {
case Some(newRDD) =>
if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
logInfo("Persisting RDD " + newRDD.id + " for time " +
time + " to " + storageLevel + " at time " + time)
}
if (checkpointDuration != null &&
(time - zeroTime).isMultipleOf(checkpointDuration)) {
newRDD.checkpoint()
logInfo("Marking RDD " + newRDD.id + " for time " + time +
" for checkpointing at time " + time)
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
case None =>
None
// If RDD was already generated, then retrieve it from HashMap,
// or else compute the RDD
generatedRDDs.get(time).orElse {
[Contributor Author] These changes are just a refactoring of the code (from case Some and case None to Option.orElse), with no change in the logic.

[Contributor] I confirmed that this logic is the same as before.
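For readers less familiar with the idiom, a generic illustration (deliberately not the DStream code) of why the two forms are equivalent; `fallback`, `lookupOld`, and `lookupNew` are invented names:

```scala
// Hypothetical helper, only to show the Option.orElse idiom.
def fallback(key: Int): Option[String] = Some(s"computed-$key")

// Before: explicit pattern match on the cached value.
def lookupOld(cache: Map[Int, String], key: Int): Option[String] =
  cache.get(key) match {
    case Some(value) => Some(value)
    case None        => fallback(key)
  }

// After: same behaviour; orElse takes the fallback by name, so it only runs on a cache miss.
def lookupNew(cache: Map[Int, String], key: Int): Option[String] =
  cache.get(key).orElse(fallback(key))
```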

// Compute the RDD if time is valid (e.g. correct time in a sliding window)
// of RDD generation, else generate nothing.
if (isTimeValid(time)) {
// Set the thread-local property for call sites to this DStream's creation site
// such that RDDs generated by compute get that as their creation site.
// Note that this `getOrCompute` may get called from another DStream which may have
// set its own call site. So we store its call site in a temporary variable,
// set this DStream's creation site, generate RDDs and then restore the previous call site.
[Contributor] How about explaining this top-down: start with what we're trying to do (set RDD call sites properly), then how we're doing it (using the thread-local property). Right now it's easy to get lost if the reader isn't familiar with how call sites are set through SparkContext.
val prevCallSite = ssc.sparkContext.getCallSite()
ssc.sparkContext.setCallSite(creationSite)
val rddOption = compute(time)
ssc.sparkContext.setCallSite(prevCallSite)

rddOption.foreach { case newRDD =>
// Register the generated RDD for caching and checkpointing
if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
[Contributor] Did you mean to change this from info to debug?

[Contributor Author] Yes, I did; it was too verbose.
}
} else {
None
if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
newRDD.checkpoint()
logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
}
generatedRDDs.put(time, newRDD)
}
rddOption
} else {
None
}
}
}
@@ -799,3 +801,29 @@ abstract class DStream[T: ClassTag] (
this
}
}

private[streaming] object DStream {

/** Get the creation site of a DStream from the stack trace of when the DStream is created. */
def getCreationSite(): CallSite = {
val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
val SCALA_CLASS_REGEX = """^scala""".r

/** Filtering function that excludes non-user classes for a streaming application */
def streamingExclusionFunction(className: String): Boolean = {
def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
val isScalaClass = doesMatch(SCALA_CLASS_REGEX)

// If the class is a Spark example class or a streaming test class, then it is considered
// a streaming application class and is not excluded. Otherwise, exclude any Spark or
// Scala class, as the remaining classes are the streaming application's own classes.
(isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass
}
org.apache.spark.util.Utils.getCallSite(streamingExclusionFunction)
}
}
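As a rough sketch of what this yields for an application that lives outside the excluded packages (the app name, host, port, and line number below are all made up):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Hypothetical user program; since it is not under org.apache.spark or scala,
// its frames survive the streamingExclusionFunction filter above.
object WordCountApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("WordCountApp")
    val ssc = new StreamingContext(conf, Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)
    // The creation site recorded for `lines` would read roughly
    // "socketTextStream at WordCountApp.scala:11", and every RDD this DStream
    // generates inherits that call site in the Spark UI.
    lines.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```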
streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -19,13 +19,16 @@ package org.apache.spark.streaming

import java.util.concurrent.atomic.AtomicInteger

import scala.language.postfixOps

import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.{MetadataCleaner, Utils}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.util.Utils
import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._

@@ -257,6 +260,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
}

test("DStream and generated RDD creation sites") {
testPackage.test()
}

def addInputStream(s: StreamingContext): DStream[Int] = {
val input = (1 to 100).map(i => (1 to i))
val inputStream = new TestInputStream(s, input, 1)
@@ -293,3 +300,37 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging
object TestReceiver {
val counter = new AtomicInteger(1)
}

/** Streaming application for testing DStream and RDD creation sites */
package object testPackage extends Assertions {
[Contributor] minor: why a package? Also, classes / objects should be capitalized. Actually I would just put this in StreamingContextSuite and do away with this class.

[Contributor Author] This is a package so that I can exclude org.apache.spark.streaming.testPackage in the filter function passed to Utils.getCallSite() and treat this package as user code for testing. Also, I was following the strategy used for testing the Spark core call site in SparkContextInfoSuite.

Regarding naming, besides the example in SparkContextInfoSuite, the official Scala style guide says:

  1. Objects follow the class naming convention (camelCase with a capital first letter) except when attempting to mimic a package or a function.
  2. Scala packages should follow the Java package naming conventions.

[Contributor] Ah, I did not realize there's already another suite that does that. It's OK to keep it then.
def test() {
val conf = new SparkConf().setMaster("local").setAppName("CreationSite test")
val ssc = new StreamingContext(conf, Milliseconds(100))
try {
val inputStream = ssc.receiverStream(new TestReceiver)

// Verify creation site of DStream
val creationSite = inputStream.creationSite
assert(creationSite.shortForm.contains("receiverStream") &&
creationSite.shortForm.contains("StreamingContextSuite")
)
assert(creationSite.longForm.contains("testPackage"))

// Verify creation site of generated RDDs
var rddGenerated = false
var rddCreationSiteCorrect = true

inputStream.foreachRDD { rdd =>
rddCreationSiteCorrect = rdd.creationSite == creationSite
rddGenerated = true
}
ssc.start()

eventually(timeout(10000 millis), interval(10 millis)) {
assert(rddGenerated && rddCreationSiteCorrect, "RDD creation site was not correct")
}
} finally {
ssc.stop()
}
}
}