Skip to content

Commit 4e8ac6e

Browse files
ericlhvanhovell
authored andcommitted
[SPARK-15735] Allow specifying min time to run in microbenchmarks
## What changes were proposed in this pull request? This makes microbenchmarks run for at least 2 seconds by default, to allow some time for jit compilation to kick in. ## How was this patch tested? Tested manually with existing microbenchmarks. This change is backwards compatible in that existing microbenchmarks which specified numIters per-case will still run exactly that number of iterations. Microbenchmarks which previously overrode defaultNumIters now override minNumIters. cc hvanhovell Author: Eric Liang <[email protected]> Author: Eric Liang <[email protected]> Closes #13472 from ericl/spark-15735.
1 parent ca70ab2 commit 4e8ac6e

File tree

1 file changed

+72
-37
lines changed

1 file changed

+72
-37
lines changed

core/src/main/scala/org/apache/spark/util/Benchmark.scala

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717

1818
package org.apache.spark.util
1919

20+
import java.io.{OutputStream, PrintStream}
21+
2022
import scala.collection.mutable
2123
import scala.collection.mutable.ArrayBuffer
24+
import scala.concurrent.duration._
2225
import scala.util.Try
2326

27+
import org.apache.commons.io.output.TeeOutputStream
2428
import org.apache.commons.lang3.SystemUtils
2529

2630
/**
@@ -33,18 +37,37 @@ import org.apache.commons.lang3.SystemUtils
3337
*
3438
* The benchmark function takes one argument that is the iteration that's being run.
3539
*
36-
* If outputPerIteration is true, the timing for each run will be printed to stdout.
40+
* @param name name of this benchmark.
41+
* @param valuesPerIteration number of values used in the test case, used to compute rows/s.
42+
* @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
43+
* @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
44+
* @param minTime further iterations will be run for each case until this time is used up.
45+
* @param outputPerIteration if true, the timing for each run will be printed to stdout.
46+
* @param output optional output stream to write benchmark results to
3747
*/
3848
private[spark] class Benchmark(
3949
name: String,
4050
valuesPerIteration: Long,
41-
defaultNumIters: Int = 5,
42-
outputPerIteration: Boolean = false) {
51+
minNumIters: Int = 2,
52+
warmupTime: FiniteDuration = 2.seconds,
53+
minTime: FiniteDuration = 2.seconds,
54+
outputPerIteration: Boolean = false,
55+
output: Option[OutputStream] = None) {
56+
import Benchmark._
4357
val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
4458

59+
val out = if (output.isDefined) {
60+
new PrintStream(new TeeOutputStream(System.out, output.get))
61+
} else {
62+
System.out
63+
}
64+
4565
/**
4666
* Adds a case to run when run() is called. The given function will be run for several
4767
* iterations to collect timing statistics.
68+
*
69+
* @param name of the benchmark case
70+
* @param numIters if non-zero, forces exactly this many iterations to be run
4871
*/
4972
def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
5073
addTimerCase(name, numIters) { timer =>
@@ -58,9 +81,12 @@ private[spark] class Benchmark(
5881
* Adds a case with manual timing control. When the function is run, timing does not start
5982
* until timer.startTiming() is called within the given function. The corresponding
6083
* timer.stopTiming() method must be called before the function returns.
84+
*
85+
* @param name of the benchmark case
86+
* @param numIters if non-zero, forces exactly this many iterations to be run
6187
*/
6288
def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
63-
benchmarks += Benchmark.Case(name, f, if (numIters == 0) defaultNumIters else numIters)
89+
benchmarks += Benchmark.Case(name, f, numIters)
6490
}
6591

6692
/**
@@ -75,28 +101,63 @@ private[spark] class Benchmark(
75101

76102
val results = benchmarks.map { c =>
77103
println(" Running case: " + c.name)
78-
Benchmark.measure(valuesPerIteration, c.numIters, outputPerIteration)(c.fn)
104+
measure(valuesPerIteration, c.numIters)(c.fn)
79105
}
80106
println
81107

82108
val firstBest = results.head.bestMs
83109
// The results are going to be processor specific so it is useful to include that.
84-
println(Benchmark.getJVMOSInfo())
85-
println(Benchmark.getProcessorName())
86-
printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
110+
out.println(Benchmark.getJVMOSInfo())
111+
out.println(Benchmark.getProcessorName())
112+
out.printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
87113
"Per Row(ns)", "Relative")
88-
println("-" * 96)
114+
out.println("-" * 96)
89115
results.zip(benchmarks).foreach { case (result, benchmark) =>
90-
printf("%-40s %16s %12s %13s %10s\n",
116+
out.printf("%-40s %16s %12s %13s %10s\n",
91117
benchmark.name,
92118
"%5.0f / %4.0f" format (result.bestMs, result.avgMs),
93119
"%10.1f" format result.bestRate,
94120
"%6.1f" format (1000 / result.bestRate),
95121
"%3.1fX" format (firstBest / result.bestMs))
96122
}
97-
println
123+
out.println
98124
// scalastyle:on
99125
}
126+
127+
/**
128+
* Runs a single function `f` for iters, returning the average time the function took and
129+
* the rate of the function.
130+
*/
131+
def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
132+
System.gc() // ensures garbage from previous cases don't impact this one
133+
val warmupDeadline = warmupTime.fromNow
134+
while (!warmupDeadline.isOverdue) {
135+
f(new Benchmark.Timer(-1))
136+
}
137+
val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
138+
val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
139+
val runTimes = ArrayBuffer[Long]()
140+
var i = 0
141+
while (i < minIters || runTimes.sum < minDuration) {
142+
val timer = new Benchmark.Timer(i)
143+
f(timer)
144+
val runTime = timer.totalTime()
145+
runTimes += runTime
146+
147+
if (outputPerIteration) {
148+
// scalastyle:off
149+
println(s"Iteration $i took ${runTime / 1000} microseconds")
150+
// scalastyle:on
151+
}
152+
i += 1
153+
}
154+
// scalastyle:off
155+
println(s" Stopped after $i iterations, ${runTimes.sum / 1000000} ms")
156+
// scalastyle:on
157+
val best = runTimes.min
158+
val avg = runTimes.sum / runTimes.size
159+
Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
160+
}
100161
}
101162

102163
private[spark] object Benchmark {
@@ -161,30 +222,4 @@ private[spark] object Benchmark {
161222
val osVersion = System.getProperty("os.version")
162223
s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
163224
}
164-
165-
/**
166-
* Runs a single function `f` for iters, returning the average time the function took and
167-
* the rate of the function.
168-
*/
169-
def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = {
170-
val runTimes = ArrayBuffer[Long]()
171-
for (i <- 0 until iters + 1) {
172-
val timer = new Benchmark.Timer(i)
173-
f(timer)
174-
val runTime = timer.totalTime()
175-
if (i > 0) {
176-
runTimes += runTime
177-
}
178-
179-
if (outputPerIteration) {
180-
// scalastyle:off
181-
println(s"Iteration $i took ${runTime / 1000} microseconds")
182-
// scalastyle:on
183-
}
184-
}
185-
val best = runTimes.min
186-
val avg = runTimes.sum / iters
187-
Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
188-
}
189225
}
190-

0 commit comments

Comments
 (0)