diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff32..0ecb3147ac39 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.{RateEstimator, NoopRateEstimator, PIDRateEstimator} import org.apache.spark.util.Utils /** @@ -47,6 +49,25 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + + private def resolveRateEstimator(configString: String): RateEstimator = configString match { + case "pid" => new PIDRateEstimator(ssc.graph.batchDuration.milliseconds) + case _ => new NoopRateEstimator() + } + /** + * A rate estimator configured by the user to compute a dynamic ingestion bound for this stream. + * @see `RateEstimator` + */ + protected [streaming] val rateEstimator = + resolveRateEstimator(ssc.conf + .getOption("spark.streaming.RateEstimator") + .getOrElse("noop")) + + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: RateController = new RateController(id, rateEstimator) { + override def publish(rate: Long): Unit = () + } + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a50f0efc030c..39e369476382 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,8 @@ import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.NoopRateEstimator import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -40,6 +41,14 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override val rateController: RateController = new RateController(id, rateEstimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b56a..d3f257429c95 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,6 +66,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // Estimators receive updates from batch completion + ssc.graph.getInputStreams.map(_.rateController).foreach(ssc.addStreamingListener(_)) listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 000000000000..82244498cc05 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.ThreadUtils + +import scala.concurrent.{ExecutionContext, Future} + +/** + * :: DeveloperApi :: + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +@DeveloperApi +private [streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + protected def publish(rate: Long): Unit + + // Used to compute & publish the rate update asynchronously + @transient private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + + private val rateLimit : AtomicLong = new AtomicLong(-1L) + + // Asynchronous computation of the rate update + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newSpeed = rateEstimator.compute(time, elems, workDelay, waitDelay) + newSpeed foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } (executionContext) + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted){ + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for ( + processingEnd <- batchCompleted.batchInfo.processingEndTime; + workDelay <- batchCompleted.batchInfo.processingDelay; + waitDelay <- batchCompleted.batchInfo.schedulingDelay; + elems <- elements.get(streamUID).map(_.numRecords) + ) computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala new file mode 100644 index 000000000000..ab1970e9aadc --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.Logging + +/** + * Implements a proportional-integral-derivative (PID) controller which acts on + * the speed of ingestion of elements into Spark Streaming. + * + * @param batchDurationMillis the batch duration, in milliseconds + * @param proportional how much the correction should depend on the current + * error, + * @param integral how much the correction should depend on the accumulation + * of past errors, + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change + */ +private[streaming] class PIDRateEstimator(batchIntervalMillis: Long, + proportional: Double = -1D, + integral: Double = -.2D, + derivative: Double = 0D) + extends RateEstimator with Logging { + + private var init: Boolean = true + private var latestTime : Long = -1L + private var latestSpeed : Double = -1D + private var latestError : Double = -1L + + if (batchIntervalMillis <= 0) logError("Specified batch interval ${batchIntervalMillis} " + + "in PIDRateEstimator is invalid.") + + def compute(time: Long, // in milliseconds + elements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + + this.synchronized { + if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingSpeed = elements.toDouble / processingDelay * 1000 + + // in elements/second + val error = latestSpeed - processingSpeed + + // in elements/second + val sumError = schedulingDelay.toDouble * processingSpeed / batchIntervalMillis + + // in elements/(second ^ 2) + val dError = (error - latestError) / delaySinceUpdate + + val newSpeed = (latestSpeed + proportional * error + + integral * sumError + + derivative * dError) max 0D + latestTime = time + if (init) { + latestSpeed = processingSpeed + latestError = 0D + init = false + + None + } else { + latestSpeed = newSpeed + latestError = error + + Some(newSpeed) + } + } else None + } + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 000000000000..1e1ccf135ad7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * A component that estimates the rate at wich an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +@DeveloperApi +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + */ + def compute(time: Long, elements: Long, + processingDelay: Long, schedulingDelay: Long): Option[Double] +} + +/** + * The trivial rate estimator never sends an update + */ +private[streaming] class NoopRateEstimator extends RateEstimator { + + def compute(time: Long, elements: Long, + processingDelay: Long, schedulingDelay: Long): Option[Double] = None +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc..8f1f87cc2238 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,15 +21,16 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.rate._ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.Logging class StreamingListenerSuite extends TestSuiteBase with Matchers { @@ -131,6 +132,37 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + // This test is long to run an may be dependent on your machine's + // characteristics (high variance in estimating processing speed on a + // small batch) + ignore("latest speed reporting") { + val operation = (d: DStream[Int]) => d.map(Thread.sleep(_)) + val midInput = Seq.fill(10)(Seq.fill(100)(1)) + val midSsc = setupStreams(midInput, operation) + val midLatestRate = new RateController(0, + new PIDRateEstimator(batchDuration.milliseconds, -1, 0, 0)){ + def publish(r: Long): Unit = () + } + midSsc.addStreamingListener(midLatestRate) + runStreams(midSsc, midInput.size, midInput.size) + + val midSp = midLatestRate.getLatestRate() + + // between two batch sizes that are both below the system's limits, + // the estimate of elements processed per batch should be comparable + val bigInput = Seq.fill(10)(Seq.fill(500)(1)) + val bigSsc = setupStreams(bigInput, operation) + val bigLatestRate = new RateController(0, + new PIDRateEstimator(batchDuration.milliseconds, -1, 0, 0)){ + def publish(r: Long): Unit = () + } + bigSsc.addStreamingListener(bigLatestRate) + runStreams(bigSsc, bigInput.size, bigInput.size) + + val bigSp = bigLatestRate.getLatestRate() + bigSp should (be >= (midSp / 2) and be <= (midSp * 2)) + } + /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { for (i <- 1 until seq.size) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala new file mode 100644 index 000000000000..a1722b2ed606 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkFunSuite +import org.scalatest._ +import org.scalatest.Matchers +import org.scalatest.Inspectors._ + +import scala.util.Random + +class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { + + test("first bound is None"){ + val p = new PIDRateEstimator(20, -1D, 0D, 0D) + p.compute(0, 10, 10, 0) should equal (None) + } + + test("second bound is rate"){ + val p = new PIDRateEstimator(20, -1D, 0D, 0D) + p.compute(0, 10, 10, 0) + // 1000 elements / s + p.compute(10, 10, 10, 0) should equal (Some(1000)) + } + + test("works even with no time between updates"){ + val p = new PIDRateEstimator(20, -1D, 0D, 0D) + p.compute(0, 10, 10, 0) + p.compute(10, 10, 10, 0) + p.compute(10, 10, 10, 0) should equal (None) + } + + test("works even with a zero batch interval"){ + val p = new PIDRateEstimator(0, -1D, 0D, 0D) + p.compute(0, 10, 10, 0) should equal (None) + p.compute(10, 10, 10, 0) should equal (None) + } + + test("bound is never negative"){ + val p = new PIDRateEstimator(20, -1D, -1D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.fill(50)(0) // no processing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(100) // strictly positive accumulation + val res = for (i <- List.range(0, 50)) yield + p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal (None) + res.tail should equal (List.fill(49)(Some(0D))) + } + + + test("with no accumulated or positive error, |I| > 0, follow the processing speed"){ + val p = new PIDRateEstimator(20, -1D, -1D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => x * 20) // increasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield + p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal (None) + res.tail should equal (List.tabulate(50)(x => Some(x * 1000D)).tail) + } + + test("with no accumulated but some positive error, |I| > 0, follow the processing speed"){ + val p = new PIDRateEstimator(20, -1D, -1D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (50-x) * 20) // decreasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield + p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal (None) + res.tail should equal (List.tabulate(50)(x => Some((50-x) * 1000D)).tail) + } + + test("with some accumulated and some positive error, |I| > 0, stay below the processing speed"){ + val p = new PIDRateEstimator(20, -1D, -.01D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val rng = new Random() + val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val procDelayMs = 20 + val proc = List.fill(50)(procDelayMs) // 20ms of processing + val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) + + val res = for (i <- List.range(0, 50)) yield + p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal (None) + forAll (List.range(1, 50)){ (n) => + res(n) should not be None + if (res(n).get > 0 && sched(n) > 0) { + res(n).get should be < speeds(n) + } + } + } + +}