InputDStream.scala
@@ -22,6 +22,8 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDDOperationScope
import org.apache.spark.streaming.{Time, Duration, StreamingContext}
import org.apache.spark.streaming.scheduler.RateController
import org.apache.spark.streaming.scheduler.rate.{RateEstimator, NoopRateEstimator, PIDRateEstimator}
import org.apache.spark.util.Utils

/**
@@ -47,6 +49,25 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
/** This is a unique identifier for the input stream. */
val id = ssc.getNewInputStreamId()


private def resolveRateEstimator(configString: String): RateEstimator = configString match {
case "pid" => new PIDRateEstimator(ssc.graph.batchDuration.milliseconds)
case _ => new NoopRateEstimator()
}
/**
* A rate estimator configured by the user to compute a dynamic ingestion bound for this stream.
* @see `RateEstimator`
*/
protected[streaming] val rateEstimator =
resolveRateEstimator(ssc.conf
.getOption("spark.streaming.RateEstimator")
.getOrElse("noop"))

// Keep track of the latest rate for this stream, as computed by the rateEstimator
protected[streaming] val rateController: RateController = new RateController(id, rateEstimator) {
override def publish(rate: Long): Unit = ()
}

/** A human-readable name of this InputDStream */
private[streaming] def name: String = {
// e.g. FlumePollingDStream -> "Flume polling stream"
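
For reference, a minimal sketch of how an application would opt in to the PID estimator through the configuration key read by resolveRateEstimator above ("noop" is the default when the key is unset); the app name and batch interval here are illustrative assumptions, not part of this diff:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // Select the PID estimator; any other value falls back to NoopRateEstimator.
    val conf = new SparkConf()
      .setAppName("RateControlledApp") // hypothetical app name
      .set("spark.streaming.RateEstimator", "pid")
    val ssc = new StreamingContext(conf, Seconds(1)) // assumed 1s batch interval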
ReceiverInputDStream.scala
@@ -24,7 +24,8 @@ import org.apache.spark.storage.BlockId
import org.apache.spark.streaming._
import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.StreamInputInfo
import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
import org.apache.spark.streaming.scheduler.rate.NoopRateEstimator
import org.apache.spark.streaming.util.WriteAheadLogUtils

/**
@@ -40,6 +41,14 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils
abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
extends InputDStream[T](ssc_) {

/**
* Asynchronously maintains and sends new rate limits to the receiver through the receiver tracker.
*/
override val rateController: RateController = new RateController(id, rateEstimator) {
override def publish(rate: Long): Unit =
ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
}

/**
* Gets the receiver object that will be sent to the worker nodes
* to receive data. This method needs to be defined by any specific implementation
JobScheduler.scala
@@ -66,6 +66,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
eventLoop.start()

// Attach rate controllers as listeners so their estimators receive updates on batch completion
ssc.graph.getInputStreams.map(_.rateController).foreach(ssc.addStreamingListener(_))
listenerBus.start(ssc.sparkContext)
receiverTracker = new ReceiverTracker(ssc)
inputInfoTracker = new InputInfoTracker(ssc)
RateController.scala
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.scheduler

import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.ThreadUtils

/**
* :: DeveloperApi ::
* A StreamingListener that receives batch completion updates, and maintains
* an estimate of the speed at which this stream should ingest messages,
* given an estimate computation from a `RateEstimator`.
*/
@DeveloperApi
Review comment: This must be private[streaming]

Author: I assume the same goes for RateEstimator, NoopRateEstimator, and PIDRateEstimator, at least for now, right?

Review comment: Yep.

Author: Done.

private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
extends StreamingListener with Serializable {

protected def publish(rate: Long): Unit

// Used to compute & publish the rate update asynchronously
@transient private val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))

private val rateLimit: AtomicLong = new AtomicLong(-1L)

// Asynchronous computation of the rate update
private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
Future[Unit] {
val newSpeed = rateEstimator.compute(time, elems, workDelay, waitDelay)
newSpeed foreach { s =>
rateLimit.set(s.toLong)
publish(getLatestRate())
}
} (executionContext)

def getLatestRate(): Long = rateLimit.get()

override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
val elements = batchCompleted.batchInfo.streamIdToInputInfo

for (
processingEnd <- batchCompleted.batchInfo.processingEndTime;
workDelay <- batchCompleted.batchInfo.processingDelay;
waitDelay <- batchCompleted.batchInfo.schedulingDelay;
elems <- elements.get(streamUID).map(_.numRecords)
) computeAndPublish(processingEnd, elems, workDelay, waitDelay)
}

}
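
To illustrate the intended use of this abstraction, here is a minimal sketch of a concrete controller (not part of this diff; the LoggingRateController name and the println sink are assumptions). Because RateController is private[streaming], such a class would have to live under the org.apache.spark.streaming package:

    import org.apache.spark.streaming.scheduler.RateController
    import org.apache.spark.streaming.scheduler.rate.PIDRateEstimator

    // Hypothetical controller that publishes each new rate bound by printing it.
    private[streaming] class LoggingRateController(streamId: Int, batchIntervalMs: Long)
      extends RateController(streamId, new PIDRateEstimator(batchIntervalMs)) {
      override def publish(rate: Long): Unit =
        println(s"stream $streamId: new rate bound = $rate elements/s")
    }

    // Registered like any other listener, it then receives batch completion updates:
    //   ssc.addStreamingListener(new LoggingRateController(stream.id, 1000))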
PIDRateEstimator.scala
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.scheduler.rate

import org.apache.spark.Logging

/**
* Implements a proportional-integral-derivative (PID) controller which acts on
* the speed of ingestion of elements into Spark Streaming.
*
* @param batchIntervalMillis the batch duration, in milliseconds
* @param proportional how much the correction should depend on the current
* error,
* @param integral how much the correction should depend on the accumulation
* of past errors,
* @param derivative how much the correction should depend on a prediction
* of future errors, based on current rate of change
*/
private[streaming] class PIDRateEstimator(batchIntervalMillis: Long,
proportional: Double = -1D,
integral: Double = -.2D,
derivative: Double = 0D)
extends RateEstimator with Logging {

private var init: Boolean = true
private var latestTime: Long = -1L
private var latestSpeed: Double = -1D
private var latestError: Double = -1D

if (batchIntervalMillis <= 0) logError(s"Specified batch interval $batchIntervalMillis " +
  "in PIDRateEstimator is invalid.")

def compute(time: Long, // in milliseconds
Review comment: nit: incorrect indentation
elements: Long,
processingDelay: Long, // in milliseconds
schedulingDelay: Long // in milliseconds
): Option[Double] = {

this.synchronized {
if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) {

// in seconds, should be close to batchDuration
val delaySinceUpdate = (time - latestTime).toDouble / 1000

// in elements/second
val processingSpeed = elements.toDouble / processingDelay * 1000

// in elements/second
val error = latestSpeed - processingSpeed

// in elements/second
val sumError = schedulingDelay.toDouble * processingSpeed / batchIntervalMillis

// in elements/(second ^ 2)
val dError = (error - latestError) / delaySinceUpdate
Review comment: I wonder if delaySinceUpdate could ever be 0.0 (for instance during driver recovery). The whole expression (and further down to newSpeed) would become NaN.

Author: No, delaySinceUpdate cannot ever be 0.0 there. I've added a unit test covering that case.


val newSpeed = (latestSpeed + proportional * error +
integral * sumError +
derivative * dError) max 0D
latestTime = time
if (init) {
latestSpeed = processingSpeed
latestError = 0D
init = false

None
} else {
latestSpeed = newSpeed
latestError = error

Some(newSpeed)
}
} else None
}
}

}
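
To make the update rule concrete, a worked example with assumed numbers (batch interval 1000 ms, default gains -1, -0.2, 0); since the class is private[streaming], this sketch would need to live in the org.apache.spark.streaming.scheduler.rate package:

    val estimator = new PIDRateEstimator(1000)

    // First batch: 1000 elements processed in 500 ms, i.e. 2000 elements/s.
    // The first call only seeds latestSpeed/latestError and yields no estimate.
    estimator.compute(1000, 1000, 500, 0)   // returns None

    // Second batch: 1000 elements in 1000 ms, i.e. 1000 elements/s, no scheduling delay.
    // error = 2000 - 1000 = 1000; sumError = 0; so
    // newSpeed = 2000 + (-1) * 1000 + (-0.2) * 0 = 1000 elements/s.
    estimator.compute(2000, 1000, 1000, 0)  // returns Some(1000.0)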
RateEstimator.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.scheduler.rate

import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
* A component that estimates the rate at which an InputDStream should ingest
* elements, based on updates at every batch completion.
*/
@DeveloperApi
private[streaming] trait RateEstimator extends Serializable {

/**
* Computes the number of elements the stream attached to this `RateEstimator`
* should ingest per second, given an update on the size and completion
* times of the latest batch.
*/
def compute(time: Long, elements: Long,
processingDelay: Long, schedulingDelay: Long): Option[Double]
}

/**
* The trivial rate estimator never sends an update.
*/
private[streaming] class NoopRateEstimator extends RateEstimator {

def compute(time: Long, elements: Long,
processingDelay: Long, schedulingDelay: Long): Option[Double] = None
}
StreamingListenerSuite.scala
@@ -21,15 +21,16 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.scheduler.rate._

import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.Logging

class StreamingListenerSuite extends TestSuiteBase with Matchers {

@@ -131,6 +132,37 @@
}
}

// This test is slow to run and may be dependent on your machine's
// characteristics (high variance in estimating processing speed on a
// small batch)
ignore("latest speed reporting") {
Review comment: This test should not be in StreamingListenerSuite. Let me think about what the right design for unit tests should be.

val operation = (d: DStream[Int]) => d.map(Thread.sleep(_))
val midInput = Seq.fill(10)(Seq.fill(100)(1))
val midSsc = setupStreams(midInput, operation)
val midLatestRate = new RateController(0,
new PIDRateEstimator(batchDuration.milliseconds, -1, 0, 0)){
def publish(r: Long): Unit = ()
}
midSsc.addStreamingListener(midLatestRate)
runStreams(midSsc, midInput.size, midInput.size)

val midSp = midLatestRate.getLatestRate()

// between two batch sizes that are both below the system's limits,
// the estimate of elements processed per batch should be comparable
val bigInput = Seq.fill(10)(Seq.fill(500)(1))
val bigSsc = setupStreams(bigInput, operation)
val bigLatestRate = new RateController(0,
new PIDRateEstimator(batchDuration.milliseconds, -1, 0, 0)){
def publish(r: Long): Unit = ()
}
bigSsc.addStreamingListener(bigLatestRate)
runStreams(bigSsc, bigInput.size, bigInput.size)

val bigSp = bigLatestRate.getLatestRate()
bigSp should (be >= (midSp / 2) and be <= (midSp * 2))
}

/** Check if a sequence of numbers is in increasing order */
def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
for (i <- 1 until seq.size) {