From f981a4993485907ab60b5a1456cd7963bb610363 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 14 Jul 2015 23:37:47 -0700 Subject: [PATCH 1/3] add stopwatches --- .../apache/spark/ml/util/stopwatches.scala | 152 ++++++++++++++++++ .../apache/spark/ml/util/StopwatchSuite.scala | 109 +++++++++++++ 2 files changed, 261 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 000000000000..89f14c1cc0da --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. + */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. + */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. 
+ * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. + * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + "MultiStopwatch" + + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name} -> ${c.elapsed()}ms") + .mkString("(\n", ",\n", "\n)") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 000000000000..b4e625168f38 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + sw.start() + Thread.sleep(50) + val duration = sw.stop() + assert(duration >= 50 && duration < 100) // using a loose upper bound + val elapsed = sw.elapsed() + assert(elapsed >= 50 && elapsed < 100) + sw.start() + Thread.sleep(50) + val duration2 = sw.stop() + assert(duration2 >= 50 && duration2 < 100) + val elapsed2 = sw.elapsed() + assert(elapsed2 >= 100 && elapsed2 < 200) + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw.start() + Thread.sleep(50) + sw.stop() + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "MultiStopwatch(\n local -> 0ms,\n spark -> 0ms\n)") + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("local").stop() + Thread.sleep(50) + sw("spark").stop() + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed >= 50 && localElapsed < 100) + assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(sw.toString === + s"MultiStopwatch(\n local -> ${localElapsed}ms,\n spark -> ${sparkElapsed}ms\n)") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("spark").stop() + sw("local").stop() + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + } +} From c4777451108afd504bbb3a3ac631a1a8d4d6f6f8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Jul 2015 12:50:22 -0700 Subject: [PATCH 2/3] address Joseph's comments --- .../scala/org/apache/spark/ml/util/stopwatches.scala | 11 +++++------ .../org/apache/spark/ml/util/StopwatchSuite.scala | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala index 89f14c1cc0da..5fdf878a3df7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -50,7 +50,7 @@ private[spark] abstract class Stopwatch extends Serializable { */ def stop(): Long = { assume(running, "stop() called but the stopwatch is not running.") - val duration = now - startTime + val duration = now - startTime add(duration) running = false duration @@ -100,7 +100,7 @@ private[spark] class DistributedStopwatch( 
sc: SparkContext, override val name: String) extends Stopwatch { - val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") override def elapsed(): Long = elapsedTime.value @@ -144,9 +144,8 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext def apply(name: String): Stopwatch = stopwatches(name) override def toString: String = { - "MultiStopwatch" + - stopwatches.values.toArray.sortBy(_.name) - .map(c => s" ${c.name} -> ${c.elapsed()}ms") - .mkString("(\n", ",\n", "\n)") + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name}: ${c.elapsed()}ms") + .mkString("{\n", ",\n", "\n}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index b4e625168f38..2c8cca9ed26a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -34,13 +34,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { val duration = sw.stop() assert(duration >= 50 && duration < 100) // using a loose upper bound val elapsed = sw.elapsed() - assert(elapsed >= 50 && elapsed < 100) + assert(elapsed === duration) sw.start() Thread.sleep(50) val duration2 = sw.stop() assert(duration2 >= 50 && duration2 < 100) val elapsed2 = sw.elapsed() - assert(elapsed2 >= 100 && elapsed2 < 200) + assert(elapsed2 == duration + duration2) sw.start() assert(sw.isRunning) intercept[AssertionError] { @@ -80,7 +80,7 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[NoSuchElementException] { sw("some") } - assert(sw.toString === "MultiStopwatch(\n local -> 0ms,\n spark -> 0ms\n)") + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") sw("local").start() sw("spark").start() Thread.sleep(50) @@ -92,7 +92,7 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { assert(localElapsed >= 50 && localElapsed < 100) assert(sparkElapsed >= 100 && sparkElapsed < 200) assert(sw.toString === - s"MultiStopwatch(\n local -> ${localElapsed}ms,\n spark -> ${sparkElapsed}ms\n)") + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) rdd.foreach { i => sw("local").start() From 40b43476dafcd42a562027740f4efe7089d0efd4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Jul 2015 20:20:33 -0700 Subject: [PATCH 3/3] == -> === --- .../test/scala/org/apache/spark/ml/util/StopwatchSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 2c8cca9ed26a..8df6617fe022 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -40,7 +40,7 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { val duration2 = sw.stop() assert(duration2 >= 50 && duration2 < 100) val elapsed2 = sw.elapsed() - assert(elapsed2 == duration + duration2) + assert(elapsed2 === duration + duration2) sw.start() assert(sw.isRunning) intercept[AssertionError] {
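
For reference, below is a minimal sketch of how these stopwatches might be threaded through an iterative spark.ml algorithm: a LocalStopwatch for driver-only setup and a DistributedStopwatch (accumulator-backed) for work done inside tasks. The StopwatchExample object, the stopwatch names, and the per-partition iter.sum work are illustrative assumptions and not part of this patch; since the classes are private[spark], such code would live inside Spark's own packages.

package org.apache.spark.ml.util

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

/**
 * Illustrative only: times driver-side setup with a local stopwatch and executor-side
 * per-partition work with a distributed (accumulator-backed) stopwatch.
 */
private[spark] object StopwatchExample {

  def runIterations(sc: SparkContext, data: RDD[Double], numIterations: Int): Double = {
    val timers = new MultiStopwatch(sc)
      .addLocal("init")         // driver-only work
      .addDistributed("map")    // work executed inside tasks on executors

    timers("init").start()
    val cached = data.cache()
    cached.count()              // materialize the cache before timing iterations
    timers("init").stop()

    var sum = 0.0
    var i = 0
    while (i < numIterations) {
      sum = cached.mapPartitions { iter =>
        // Each task's duration is added to the shared named accumulator.
        timers("map").start()
        val partial = iter.sum  // stand-in for real per-partition work
        timers("map").stop()
        Iterator.single(partial)
      }.reduce(_ + _)
      i += 1
    }

    // Prints accumulated times, e.g. "{\n  init: 12ms,\n  map: 480ms\n}".
    println(timers)
    sum
  }
}

The split mirrors the intent of the two implementations: LocalStopwatch avoids accumulator overhead for timing that never leaves the driver, while DistributedStopwatch aggregates task-side durations through a named accumulator, the same pattern exercised in the "DistributedStopwatch on executors" and "MultiStopwatch" tests above.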