Skip to content

Commit 2512e67

Browse files
ajtulloch authored and mengxr committed
LocalSparkContext for MLlib
1 parent 20d9458 commit 2512e67

File tree

10 files changed

+42
-109
lines changed

10 files changed

+42
-109
lines changed

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.classification
1919

20+
import org.apache.spark.mllib.util.LocalSparkContext
2021
import scala.util.Random
2122
import scala.collection.JavaConversions._
2223

@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
6667

6768
}
6869

69-
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
70-
@transient private var sc: SparkContext = _
71-
72-
override def beforeAll() {
73-
sc = new SparkContext("local", "test")
74-
}
75-
76-
77-
override def afterAll() {
78-
sc.stop()
79-
System.clearProperty("spark.driver.port")
80-
}
81-
70+
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
8271
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
8372
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
8473
prediction != expected.label

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.classification
1919

20+
import org.apache.spark.mllib.util.LocalSparkContext
2021
import scala.util.Random
2122

2223
import org.scalatest.BeforeAndAfterAll
@@ -59,17 +60,7 @@ object NaiveBayesSuite {
5960
}
6061
}
6162

62-
class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
63-
@transient private var sc: SparkContext = _
64-
65-
override def beforeAll() {
66-
sc = new SparkContext("local", "test")
67-
}
68-
69-
override def afterAll() {
70-
sc.stop()
71-
System.clearProperty("spark.driver.port")
72-
}
63+
class NaiveBayesSuite extends FunSuite with LocalSparkContext {
7364

7465
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
7566
val numOfPredictions = predictions.zip(input).count {

mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
2525

2626
import org.jblas.DoubleMatrix
2727

28-
import org.apache.spark.{SparkException, SparkContext}
28+
import org.apache.spark.SparkException
2929
import org.apache.spark.mllib.regression._
30+
import org.apache.spark.mllib.util.LocalSparkContext
3031

3132
object SVMSuite {
3233

@@ -58,17 +59,7 @@ object SVMSuite {
5859

5960
}
6061

61-
class SVMSuite extends FunSuite with BeforeAndAfterAll {
62-
@transient private var sc: SparkContext = _
63-
64-
override def beforeAll() {
65-
sc = new SparkContext("local", "test")
66-
}
67-
68-
override def afterAll() {
69-
sc.stop()
70-
System.clearProperty("spark.driver.port")
71-
}
62+
class SVMSuite extends FunSuite with LocalSparkContext {
7263

7364
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
7465
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
2121
import org.scalatest.BeforeAndAfterAll
2222
import org.scalatest.FunSuite
2323

24-
import org.apache.spark.SparkContext
24+
import org.apache.spark.mllib.util.LocalSparkContext
2525

26-
27-
class KMeansSuite extends FunSuite with BeforeAndAfterAll {
28-
@transient private var sc: SparkContext = _
29-
30-
override def beforeAll() {
31-
sc = new SparkContext("local", "test")
32-
}
33-
34-
override def afterAll() {
35-
sc.stop()
36-
System.clearProperty("spark.driver.port")
37-
}
26+
class KMeansSuite extends FunSuite with LocalSparkContext {
3827

3928
val EPSILON = 1e-4
4029

mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
2626

2727
import org.apache.spark.SparkContext
2828
import org.apache.spark.mllib.regression._
29+
import org.apache.spark.mllib.util.LocalSparkContext
2930

3031
object GradientDescentSuite {
3132

@@ -62,17 +63,7 @@ object GradientDescentSuite {
6263
}
6364
}
6465

65-
class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
66-
@transient private var sc: SparkContext = _
67-
68-
override def beforeAll() {
69-
sc = new SparkContext("local", "test")
70-
}
71-
72-
override def afterAll() {
73-
sc.stop()
74-
System.clearProperty("spark.driver.port")
75-
}
66+
class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
7667

7768
test("Assert the loss is decreasing.") {
7869
val nPoints = 10000

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.util.Random
2323
import org.scalatest.BeforeAndAfterAll
2424
import org.scalatest.FunSuite
2525

26-
import org.apache.spark.SparkContext
26+
import org.apache.spark.mllib.util.LocalSparkContext
2727

2828
import org.jblas._
2929

@@ -73,17 +73,7 @@ object ALSSuite {
7373
}
7474

7575

76-
class ALSSuite extends FunSuite with BeforeAndAfterAll {
77-
@transient private var sc: SparkContext = _
78-
79-
override def beforeAll() {
80-
sc = new SparkContext("local", "test")
81-
}
82-
83-
override def afterAll() {
84-
sc.stop()
85-
System.clearProperty("spark.driver.port")
86-
}
76+
class ALSSuite extends FunSuite with LocalSparkContext {
8777

8878
test("rank-1 matrices") {
8979
testALS(50, 100, 1, 15, 0.7, 0.3)

mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
2222
import org.scalatest.FunSuite
2323

2424
import org.apache.spark.SparkContext
25-
import org.apache.spark.mllib.util.LinearDataGenerator
25+
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
2626

27-
28-
class LassoSuite extends FunSuite with BeforeAndAfterAll {
29-
@transient private var sc: SparkContext = _
30-
31-
override def beforeAll() {
32-
sc = new SparkContext("local", "test")
33-
}
34-
35-
36-
override def afterAll() {
37-
sc.stop()
38-
System.clearProperty("spark.driver.port")
39-
}
27+
class LassoSuite extends FunSuite with LocalSparkContext {
4028

4129
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
4230
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,9 @@ import org.scalatest.BeforeAndAfterAll
2121
import org.scalatest.FunSuite
2222

2323
import org.apache.spark.SparkContext
24-
import org.apache.spark.mllib.util.LinearDataGenerator
24+
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
2525

26-
class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
27-
@transient private var sc: SparkContext = _
28-
29-
override def beforeAll() {
30-
sc = new SparkContext("local", "test")
31-
}
32-
33-
override def afterAll() {
34-
sc.stop()
35-
System.clearProperty("spark.driver.port")
36-
}
26+
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
3727

3828
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
3929
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,10 @@ import org.scalatest.BeforeAndAfterAll
2323
import org.scalatest.FunSuite
2424

2525
import org.apache.spark.SparkContext
26-
import org.apache.spark.mllib.util.LinearDataGenerator
26+
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
2727

28-
class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
29-
@transient private var sc: SparkContext = _
3028

31-
override def beforeAll() {
32-
sc = new SparkContext("local", "test")
33-
}
34-
35-
override def afterAll() {
36-
sc.stop()
37-
System.clearProperty("spark.driver.port")
38-
}
29+
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
3930

4031
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
4132
predictions.zip(input).map { case (prediction, expected) =>
mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package org.apache.spark.mllib.util
2+
3+
import org.scalatest.Suite
4+
import org.scalatest.BeforeAndAfterAll
5+
6+
import org.apache.spark.SparkContext
7+
8+
/**
 * Mixes a local-mode `SparkContext` into a ScalaTest suite.
 *
 * A fresh context (master = "local", app name = "test") is created once in
 * `beforeAll()` and shared by every test in the suite; it is stopped in
 * `afterAll()`. The self-type restricts the mixin to ScalaTest [[Suite]]s.
 */
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
  // @transient: suites may be serialized when tasks close over them;
  // the context must not be dragged into closures.
  @transient var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
    super.beforeAll()
  }

  override def afterAll() {
    // try/finally: even if stop() throws, the driver-port property must be
    // cleared and super.afterAll() must run, or later suites inherit a stale
    // port and the BeforeAndAfterAll chain is broken.
    try {
      if (sc != null) {
        sc.stop()
      }
      sc = null
    } finally {
      System.clearProperty("spark.driver.port")
      super.afterAll()
    }
  }
}

0 commit comments

Comments
 (0)