Commit 73d92b0

[SPARK-9018] [MLLIB] add stopwatches
Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses a Spark accumulator. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#7415 from mengxr/SPARK-9018 and squashes the following commits:

40b4347 [Xiangrui Meng] == -> ===
c477745 [Xiangrui Meng] address Joseph's comments
f981a49 [Xiangrui Meng] add stopwatches
1 parent 6960a79 commit 73d92b0
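
For context on the mechanism the commit message refers to: a Spark accumulator is a shared variable that executor tasks can only add to and that only the driver can read, which is what lets the distributed stopwatch below aggregate timings across tasks. A minimal sketch of that semantics, assuming an existing SparkContext `sc` (Spark 1.x accumulator API, as used in this commit):

val elapsed = sc.accumulator(0L, "elapsed")  // named Long accumulator
sc.parallelize(0 until 4, 4).foreach { _ =>
  elapsed += 50L                             // tasks can only add to it
}
println(elapsed.value)                       // the driver reads the total: 200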

2 files changed (+260, -0)

Lines changed: 151 additions & 0 deletions (new file: the stopwatch implementation in package org.apache.spark.ml.util)
@@ -0,0 +1,151 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import scala.collection.mutable

import org.apache.spark.{Accumulator, SparkContext}

/**
 * Abstract class for stopwatches.
 */
private[spark] abstract class Stopwatch extends Serializable {

  @transient private var running: Boolean = false
  private var startTime: Long = _

  /**
   * Name of the stopwatch.
   */
  val name: String

  /**
   * Starts the stopwatch.
   * Throws an exception if the stopwatch is already running.
   */
  def start(): Unit = {
    assume(!running, "start() called but the stopwatch is already running.")
    running = true
    startTime = now
  }

  /**
   * Stops the stopwatch and returns the duration of the last session in milliseconds.
   * Throws an exception if the stopwatch is not running.
   */
  def stop(): Long = {
    assume(running, "stop() called but the stopwatch is not running.")
    val duration = now - startTime
    add(duration)
    running = false
    duration
  }

  /**
   * Checks whether the stopwatch is running.
   */
  def isRunning: Boolean = running

  /**
   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
   * is running.
   */
  def elapsed(): Long

  /**
   * Gets the current time in milliseconds.
   */
  protected def now: Long = System.currentTimeMillis()

  /**
   * Adds input duration to total elapsed time.
   */
  protected def add(duration: Long): Unit
}

/**
 * A local [[Stopwatch]].
 */
private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {

  private var elapsedTime: Long = 0L

  override def elapsed(): Long = elapsedTime

  override protected def add(duration: Long): Unit = {
    elapsedTime += duration
  }
}

/**
 * A distributed [[Stopwatch]] using a Spark accumulator.
 * @param sc SparkContext
 */
private[spark] class DistributedStopwatch(
    sc: SparkContext,
    override val name: String) extends Stopwatch {

  private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")

  override def elapsed(): Long = elapsedTime.value

  override protected def add(duration: Long): Unit = {
    elapsedTime += duration
  }
}

/**
 * A container for multiple stopwatches, both local and distributed.
 * @param sc SparkContext
 */
private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {

  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty

  /**
   * Adds a local stopwatch.
   * @param name stopwatch name
   */
  def addLocal(name: String): this.type = {
    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    stopwatches(name) = new LocalStopwatch(name)
    this
  }

  /**
   * Adds a distributed stopwatch.
   * @param name stopwatch name
   */
  def addDistributed(name: String): this.type = {
    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    stopwatches(name) = new DistributedStopwatch(sc, name)
    this
  }

  /**
   * Gets a stopwatch.
   * @param name stopwatch name
   */
  def apply(name: String): Stopwatch = stopwatches(name)

  override def toString: String = {
    stopwatches.values.toArray.sortBy(_.name)
      .map(c => s"  ${c.name}: ${c.elapsed()}ms")
      .mkString("{\n", ",\n", "\n}")
  }
}
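
As a usage illustration (not part of this commit): a minimal sketch of how an algorithm might instrument a driver-side phase and an executor-side phase with the classes above, assuming an existing SparkContext `sc`; the phase names and loop body are hypothetical.

val timer = new MultiStopwatch(sc)
  .addLocal("init")         // driver-side phase
  .addDistributed("train")  // phase timed inside executor tasks

timer("init").start()
val data = sc.parallelize(0 until 8, 4)  // stand-in for real input data
timer("init").stop()

data.foreachPartition { _ =>
  timer("train").start()
  // per-partition training work would go here
  timer("train").stop()
}

println(timer)  // multi-line summary: "{\n  init: <t>ms,\n  train: <t>ms\n}"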
Lines changed: 109 additions & 0 deletions (new file: StopwatchSuite, the accompanying test suite)
@@ -0,0 +1,109 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {

  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
    assert(sw.name === "sw")
    assert(sw.elapsed() === 0L)
    assert(!sw.isRunning)
    intercept[AssertionError] {
      sw.stop()
    }
    sw.start()
    Thread.sleep(50)
    val duration = sw.stop()
    assert(duration >= 50 && duration < 100) // using a loose upper bound
    val elapsed = sw.elapsed()
    assert(elapsed === duration)
    sw.start()
    Thread.sleep(50)
    val duration2 = sw.stop()
    assert(duration2 >= 50 && duration2 < 100)
    val elapsed2 = sw.elapsed()
    assert(elapsed2 === duration + duration2)
    sw.start()
    assert(sw.isRunning)
    intercept[AssertionError] {
      sw.start()
    }
  }

  test("LocalStopwatch") {
    val sw = new LocalStopwatch("sw")
    testStopwatchOnDriver(sw)
  }

  test("DistributedStopwatch on driver") {
    val sw = new DistributedStopwatch(sc, "sw")
    testStopwatchOnDriver(sw)
  }

  test("DistributedStopwatch on executors") {
    val sw = new DistributedStopwatch(sc, "sw")
    val rdd = sc.parallelize(0 until 4, 4)
    rdd.foreach { i =>
      sw.start()
      Thread.sleep(50)
      sw.stop()
    }
    assert(!sw.isRunning)
    val elapsed = sw.elapsed()
    assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
  }

  test("MultiStopwatch") {
    val sw = new MultiStopwatch(sc)
      .addLocal("local")
      .addDistributed("spark")
    assert(sw("local").name === "local")
    assert(sw("spark").name === "spark")
    intercept[NoSuchElementException] {
      sw("some")
    }
    assert(sw.toString === "{\n  local: 0ms,\n  spark: 0ms\n}")
    sw("local").start()
    sw("spark").start()
    Thread.sleep(50)
    sw("local").stop()
    Thread.sleep(50)
    sw("spark").stop()
    val localElapsed = sw("local").elapsed()
    val sparkElapsed = sw("spark").elapsed()
    assert(localElapsed >= 50 && localElapsed < 100)
    assert(sparkElapsed >= 100 && sparkElapsed < 200)
    assert(sw.toString ===
      s"{\n  local: ${localElapsed}ms,\n  spark: ${sparkElapsed}ms\n}")
    val rdd = sc.parallelize(0 until 4, 4)
    rdd.foreach { i =>
      sw("local").start()
      sw("spark").start()
      Thread.sleep(50)
      sw("spark").stop()
      sw("local").stop()
    }
    val localElapsed2 = sw("local").elapsed()
    assert(localElapsed2 === localElapsed)
    val sparkElapsed2 = sw("spark").elapsed()
    assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
  }
}
