Skip to content

Commit 0a1d2ca

Browse files
dragos authored and tdas committed
[SPARK-8979] Add a PID based rate estimator
Based on #7600 /cc tdas Author: Iulian Dragos <[email protected]> Author: François Garillot <[email protected]> Closes #7648 from dragos/topic/streaming-bp/pid and squashes the following commits: aa5b097 [Iulian Dragos] Add more comments, made all PID constant parameters positive, a couple more tests. 93b74f8 [Iulian Dragos] Better explanation of historicalError. 7975b0c [Iulian Dragos] Add configuration for PID. 26cfd78 [Iulian Dragos] A couple of variable renames. d0bdf7c [Iulian Dragos] Update to latest version of the code, various style and name improvements. d58b845 [François Garillot] [SPARK-8979][Streaming] Implements a PIDRateEstimator
1 parent e8bdcde commit 0a1d2ca

File tree

4 files changed

+276
-5
lines changed

4 files changed

+276
-5
lines changed

streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
4646
*/
4747
override protected[streaming] val rateController: Option[RateController] = {
4848
if (RateController.isBackPressureEnabled(ssc.conf)) {
49-
RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
49+
Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration)))
5050
} else {
5151
None
5252
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.streaming.scheduler.rate
19+
20+
/**
21+
* Implements a proportional-integral-derivative (PID) controller which acts on
22+
* the speed of ingestion of elements into Spark Streaming. A PID controller works
23+
* by calculating an '''error''' between a measured output and a desired value. In the
24+
* case of Spark Streaming the error is the difference between the measured processing
25+
* rate (number of elements/processing delay) and the previous rate.
26+
*
27+
* @see https://en.wikipedia.org/wiki/PID_controller
28+
*
29+
* @param batchIntervalMillis the batch interval, in milliseconds
30+
* @param proportional how much the correction should depend on the current
31+
* error. This term usually provides the bulk of correction and should be positive or zero.
32+
* A value too large would make the controller overshoot the setpoint, while a small value
33+
* would make the controller too insensitive. The default value is 1.
34+
* @param integral how much the correction should depend on the accumulation
35+
* of past errors. This value should be positive or 0. This term accelerates the movement
36+
* towards the desired value, but a large value may lead to overshooting. The default value
37+
* is 0.2.
38+
* @param derivative how much the correction should depend on a prediction
39+
* of future errors, based on current rate of change. This value should be positive or 0.
40+
* This term is not used very often, as it impacts stability of the system. The default
41+
* value is 0.
42+
*/
43+
private[streaming] class PIDRateEstimator(
44+
batchIntervalMillis: Long,
45+
proportional: Double = 1D,
46+
integral: Double = .2D,
47+
derivative: Double = 0D)
48+
extends RateEstimator {
49+
50+
private var firstRun: Boolean = true
51+
private var latestTime: Long = -1L
52+
private var latestRate: Double = -1D
53+
private var latestError: Double = -1L
54+
55+
require(
56+
batchIntervalMillis > 0,
57+
s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.")
58+
require(
59+
proportional >= 0,
60+
s"Proportional term $proportional in PIDRateEstimator should be >= 0.")
61+
require(
62+
integral >= 0,
63+
s"Integral term $integral in PIDRateEstimator should be >= 0.")
64+
require(
65+
derivative >= 0,
66+
s"Derivative term $derivative in PIDRateEstimator should be >= 0.")
67+
68+
69+
def compute(time: Long, // in milliseconds
70+
numElements: Long,
71+
processingDelay: Long, // in milliseconds
72+
schedulingDelay: Long // in milliseconds
73+
): Option[Double] = {
74+
75+
this.synchronized {
76+
if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) {
77+
78+
// in seconds, should be close to batchDuration
79+
val delaySinceUpdate = (time - latestTime).toDouble / 1000
80+
81+
// in elements/second
82+
val processingRate = numElements.toDouble / processingDelay * 1000
83+
84+
// In our system `error` is the difference between the desired rate and the measured rate
85+
// based on the latest batch information. We consider the desired rate to be the latest rate,
86+
// which is what this estimator calculated for the previous batch.
87+
// in elements/second
88+
val error = latestRate - processingRate
89+
90+
// The error integral, based on schedulingDelay as an indicator for accumulated errors.
91+
// A scheduling delay s corresponds to s * processingRate overflowing elements. Those
92+
// are elements that couldn't be processed in previous batches, leading to this delay.
93+
// In the following, we assume the processingRate didn't change too much.
94+
// From the number of overflowing elements we can calculate the rate at which they would be
95+
// processed by dividing it by the batch interval. This rate is our "historical" error,
96+
// or integral part, since if we subtracted this rate from the previous "calculated rate",
97+
// there wouldn't have been any overflowing elements, and the scheduling delay would have
98+
// been zero.
99+
// (in elements/second)
100+
val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis
101+
102+
// in elements/(second ^ 2)
103+
val dError = (error - latestError) / delaySinceUpdate
104+
105+
val newRate = (latestRate - proportional * error -
106+
integral * historicalError -
107+
derivative * dError).max(0.0)
108+
latestTime = time
109+
if (firstRun) {
110+
latestRate = processingRate
111+
latestError = 0D
112+
firstRun = false
113+
114+
None
115+
} else {
116+
latestRate = newRate
117+
latestError = error
118+
119+
Some(newRate)
120+
}
121+
} else None
122+
}
123+
}
124+
}

streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate
1919

2020
import org.apache.spark.SparkConf
2121
import org.apache.spark.SparkException
22+
import org.apache.spark.streaming.Duration
2223

2324
/**
2425
* A component that estimates the rate at which an InputDStream should ingest
@@ -48,12 +49,21 @@ object RateEstimator {
4849
/**
4950
* Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`.
5051
*
51-
* @return None if there is no configured estimator, otherwise an instance of RateEstimator
52+
* The only known estimator right now is `pid`.
53+
*
54+
* @return An instance of RateEstimator
5255
* @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
5356
* known estimators.
5457
*/
55-
def create(conf: SparkConf): Option[RateEstimator] =
56-
conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
57-
throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
58+
def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
59+
conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
60+
case "pid" =>
61+
val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
62+
val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
63+
val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
64+
new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)
65+
66+
case estimator =>
67+
throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
5868
}
5969
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.streaming.scheduler.rate
19+
20+
import scala.util.Random
21+
22+
import org.scalatest.Inspectors.forAll
23+
import org.scalatest.Matchers
24+
25+
import org.apache.spark.{SparkConf, SparkFunSuite}
26+
import org.apache.spark.streaming.Seconds
27+
28+
class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
29+
30+
test("the right estimator is created") {
31+
val conf = new SparkConf
32+
conf.set("spark.streaming.backpressure.rateEstimator", "pid")
33+
val pid = RateEstimator.create(conf, Seconds(1))
34+
pid.getClass should equal(classOf[PIDRateEstimator])
35+
}
36+
37+
test("estimator checks ranges") {
38+
intercept[IllegalArgumentException] {
39+
new PIDRateEstimator(0, 1, 2, 3)
40+
}
41+
intercept[IllegalArgumentException] {
42+
new PIDRateEstimator(100, -1, 2, 3)
43+
}
44+
intercept[IllegalArgumentException] {
45+
new PIDRateEstimator(100, 0, -1, 3)
46+
}
47+
intercept[IllegalArgumentException] {
48+
new PIDRateEstimator(100, 0, 0, -1)
49+
}
50+
}
51+
52+
private def createDefaultEstimator: PIDRateEstimator = {
53+
new PIDRateEstimator(20, 1D, 0D, 0D)
54+
}
55+
56+
test("first bound is None") {
57+
val p = createDefaultEstimator
58+
p.compute(0, 10, 10, 0) should equal(None)
59+
}
60+
61+
test("second bound is rate") {
62+
val p = createDefaultEstimator
63+
p.compute(0, 10, 10, 0)
64+
// 1000 elements / s
65+
p.compute(10, 10, 10, 0) should equal(Some(1000))
66+
}
67+
68+
test("works even with no time between updates") {
69+
val p = createDefaultEstimator
70+
p.compute(0, 10, 10, 0)
71+
p.compute(10, 10, 10, 0)
72+
p.compute(10, 10, 10, 0) should equal(None)
73+
}
74+
75+
test("bound is never negative") {
76+
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
77+
// prepare a series of batch updates, one every 20ms, 0 processed elements, 20ms of processing
78+
// this might point the estimator to try and decrease the bound, but we test it never
79+
// goes below zero, which would be nonsensical.
80+
val times = List.tabulate(50)(x => x * 20) // every 20ms
81+
val elements = List.fill(50)(0) // no processing
82+
val proc = List.fill(50)(20) // 20ms of processing
83+
val sched = List.fill(50)(100) // strictly positive accumulation
84+
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
85+
res.head should equal(None)
86+
res.tail should equal(List.fill(49)(Some(0D)))
87+
}
88+
89+
test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
90+
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
91+
// prepare a series of batch updates, one every 20ms with an increasing number of processed
92+
// elements in each batch, but constant processing time, and no accumulated error. Even though
93+
// the integral part is non-zero, the estimated rate should follow only the proportional term
94+
val times = List.tabulate(50)(x => x * 20) // every 20ms
95+
val elements = List.tabulate(50)(x => x * 20) // increasing
96+
val proc = List.fill(50)(20) // 20ms of processing
97+
val sched = List.fill(50)(0)
98+
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
99+
res.head should equal(None)
100+
res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail)
101+
}
102+
103+
test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
104+
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
105+
// prepare a series of batch updates, one every 20ms with a decreasing number of processed
106+
// elements in each batch, but constant processing time, and no accumulated error. Even though
107+
// the integral part is non-zero, the estimated rate should follow only the proportional term,
108+
// asking for fewer and fewer elements
109+
val times = List.tabulate(50)(x => x * 20) // every 20ms
110+
val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing
111+
val proc = List.fill(50)(20) // 20ms of processing
112+
val sched = List.fill(50)(0)
113+
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
114+
res.head should equal(None)
115+
res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail)
116+
}
117+
118+
test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
119+
val p = new PIDRateEstimator(20, 1D, .01D, 0D)
120+
val times = List.tabulate(50)(x => x * 20) // every 20ms
121+
val rng = new Random()
122+
val elements = List.tabulate(50)(x => rng.nextInt(1000))
123+
val procDelayMs = 20
124+
val proc = List.fill(50)(procDelayMs) // 20ms of processing
125+
val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait
126+
val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000)
127+
128+
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
129+
res.head should equal(None)
130+
forAll(List.range(1, 50)) { (n) =>
131+
res(n) should not be None
132+
if (res(n).get > 0 && sched(n) > 0) {
133+
res(n).get should be < speeds(n)
134+
}
135+
}
136+
}
137+
}

0 commit comments

Comments
 (0)