
package org.apache.spark.util

+import java.io.{OutputStream, PrintStream}
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
import scala.util.Try

+import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.SystemUtils

/**
@@ -33,18 +37,37 @@ import org.apache.commons.lang3.SystemUtils
 *
 * The benchmark function takes one argument that is the iteration that's being run.
 *
- * If outputPerIteration is true, the timing for each run will be printed to stdout.
+ * @param name name of this benchmark.
+ * @param valuesPerIteration number of values used in the test case, used to compute rows/s.
+ * @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
+ * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
+ * @param minTime further iterations will be run for each case until this time is used up.
+ * @param outputPerIteration if true, the timing for each run will be printed to stdout.
+ * @param output optional output stream to write benchmark results to
 */
private[spark] class Benchmark(
    name: String,
    valuesPerIteration: Long,
-    defaultNumIters: Int = 5,
-    outputPerIteration: Boolean = false) {
+    minNumIters: Int = 2,
+    warmupTime: FiniteDuration = 2.seconds,
+    minTime: FiniteDuration = 2.seconds,
+    outputPerIteration: Boolean = false,
+    output: Option[OutputStream] = None) {
+  import Benchmark._
  val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]

+  val out = if (output.isDefined) {
+    new PrintStream(new TeeOutputStream(System.out, output.get))
+  } else {
+    System.out
+  }
+
  /**
   * Adds a case to run when run() is called. The given function will be run for several
   * iterations to collect timing statistics.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
   */
  def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
    addTimerCase(name, numIters) { timer =>
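A minimal usage sketch of the revised constructor and addCase (not part of this commit; the benchmark name, workload, and output path are illustrative, and the code assumes it lives in a package under org.apache.spark since Benchmark is private[spark]):

import java.io.FileOutputStream
import scala.concurrent.duration._
import org.apache.spark.util.Benchmark

object BenchmarkUsageSketch {
  def main(args: Array[String]): Unit = {
    val resultFile = new FileOutputStream("benchmark-results.txt")  // hypothetical output file
    val benchmark = new Benchmark(
      "string building",
      valuesPerIteration = 1000000L,  // used to compute the Rate(M/s) column
      minNumIters = 5,                // at least 5 timed iterations per case
      warmupTime = 2.seconds,         // untimed JIT warm-up before timing starts
      minTime = 2.seconds,            // keep iterating until this much timed work accumulates
      output = Some(resultFile))      // results are tee'd to stdout and this stream

    benchmark.addCase("StringBuilder append") { _ =>
      val sb = new StringBuilder
      var i = 0
      while (i < 1000000) { sb.append(i); i += 1 }
    }
    benchmark.run()
    resultFile.close()
  }
}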
@@ -58,9 +81,12 @@ private[spark] class Benchmark(
   * Adds a case with manual timing control. When the function is run, timing does not start
   * until timer.startTiming() is called within the given function. The corresponding
   * timer.stopTiming() method must be called before the function returns.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
   */
  def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
-    benchmarks += Benchmark.Case(name, f, if (numIters == 0) defaultNumIters else numIters)
+    benchmarks += Benchmark.Case(name, f, numIters)
  }

  /**
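A companion sketch for the manual-timing variant (again not part of the commit, and it assumes the `benchmark` instance from the previous sketch): only the work between timer.startTiming() and timer.stopTiming() is measured, so per-iteration setup stays out of the numbers.

// Assumes `benchmark` from the sketch above; data size and case name are made up.
benchmark.addTimerCase("sort, excluding setup") { timer =>
  val data = Array.fill(1000000)(scala.util.Random.nextInt())  // untimed setup
  timer.startTiming()
  java.util.Arrays.sort(data)  // only this call contributes to the timing
  timer.stopTiming()
}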
@@ -75,28 +101,63 @@ private[spark] class Benchmark(

    val results = benchmarks.map { c =>
      println("  Running case: " + c.name)
-      Benchmark.measure(valuesPerIteration, c.numIters, outputPerIteration)(c.fn)
+      measure(valuesPerIteration, c.numIters)(c.fn)
    }
    println

    val firstBest = results.head.bestMs
    // The results are going to be processor specific so it is useful to include that.
-    println(Benchmark.getJVMOSInfo())
-    println(Benchmark.getProcessorName())
-    printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
+    out.println(Benchmark.getJVMOSInfo())
+    out.println(Benchmark.getProcessorName())
+    out.printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
      "Per Row(ns)", "Relative")
-    println("-" * 96)
+    out.println("-" * 96)
    results.zip(benchmarks).foreach { case (result, benchmark) =>
-      printf("%-40s %16s %12s %13s %10s\n",
+      out.printf("%-40s %16s %12s %13s %10s\n",
        benchmark.name,
        "%5.0f / %4.0f" format (result.bestMs, result.avgMs),
        "%10.1f" format result.bestRate,
        "%6.1f" format (1000 / result.bestRate),
        "%3.1fX" format (firstBest / result.bestMs))
    }
-    println
+    out.println
    // scalastyle:on
  }
+
+  /**
+   * Runs a single function `f` for iters, returning the average time the function took and
+   * the rate of the function.
+   */
+  def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
+    System.gc()  // ensures garbage from previous cases don't impact this one
+    val warmupDeadline = warmupTime.fromNow
+    while (!warmupDeadline.isOverdue) {
+      f(new Benchmark.Timer(-1))
+    }
+    val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
+    val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
+    val runTimes = ArrayBuffer[Long]()
+    var i = 0
+    while (i < minIters || runTimes.sum < minDuration) {
+      val timer = new Benchmark.Timer(i)
+      f(timer)
+      val runTime = timer.totalTime()
+      runTimes += runTime
+
+      if (outputPerIteration) {
+        // scalastyle:off
+        println(s"Iteration $i took ${runTime / 1000} microseconds")
+        // scalastyle:on
+      }
+      i += 1
+    }
+    // scalastyle:off
+    println(s"  Stopped after $i iterations, ${runTimes.sum / 1000000} ms")
+    // scalastyle:on
+    val best = runTimes.min
+    val avg = runTimes.sum / runTimes.size
+    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
+  }
}

private[spark] object Benchmark {
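To make the new stopping rule concrete (illustrative numbers, not taken from the diff): with the defaults minNumIters = 2 and minTime = 2 seconds, a case whose iterations take about 300 ms each does not stop after 2 iterations; it keeps running until the accumulated timed work reaches 2 seconds, i.e. 7 iterations, preceded by roughly 2 seconds of untimed warm-up. Passing a non-zero numIters to addCase/addTimerCase disables the time-based extension and forces exactly that many timed iterations (the warm-up loop still runs).

// Sketch of the arithmetic only, assuming ~300 ms per iteration and the default settings.
import scala.concurrent.duration._

val perIterNanos = 300.milliseconds.toNanos
val minIters = 2                     // default minNumIters
val minDuration = 2.seconds.toNanos  // default minTime
// The loop condition is `i < minIters || runTimes.sum < minDuration`, so it runs
// max(minIters, ceil(minDuration / perIterNanos)) = max(2, 7) = 7 timed iterations.
val expectedIters = math.max(minIters, math.ceil(minDuration.toDouble / perIterNanos).toInt)
println(expectedIters)  // 7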
@@ -161,30 +222,4 @@ private[spark] object Benchmark {
    val osVersion = System.getProperty("os.version")
    s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
  }
-
-  /**
-   * Runs a single function `f` for iters, returning the average time the function took and
-   * the rate of the function.
-   */
-  def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = {
-    val runTimes = ArrayBuffer[Long]()
-    for (i <- 0 until iters + 1) {
-      val timer = new Benchmark.Timer(i)
-      f(timer)
-      val runTime = timer.totalTime()
-      if (i > 0) {
-        runTimes += runTime
-      }
-
-      if (outputPerIteration) {
-        // scalastyle:off
-        println(s"Iteration $i took ${runTime / 1000} microseconds")
-        // scalastyle:on
-      }
-    }
-    val best = runTimes.min
-    val avg = runTimes.sum / iters
-    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
-  }
}
-