Skip to content

Commit 4e9aed0

Browse files
committed
fix test suite
1 parent 95d8fb9 commit 4e9aed0

File tree

1 file changed

+33
-44
lines changed

1 file changed

+33
-44
lines changed

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialMapperSuite.scala

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,16 @@ import org.scalatest.FunSuite
2222
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2323
import org.apache.spark.mllib.util.MLlibTestSparkContext
2424
import org.apache.spark.mllib.util.TestingUtils._
25-
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
25+
import org.apache.spark.sql.{Row, SQLContext}
26+
import org.scalatest.exceptions.TestFailedException
2627

2728
class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
2829

29-
def collectResult(result: DataFrame): Array[Vector] = {
30-
result.select("poly_features").collect().map {
31-
case Row(features: Vector) => features
32-
}
33-
}
34-
35-
def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
36-
assert((lhs, rhs).zipped.forall {
37-
case (v1: DenseVector, v2: DenseVector) => true
38-
case (v1: SparseVector, v2: SparseVector) => true
39-
case _ => false
40-
}, "The vector type should be preserved after normalization.")
41-
}
30+
@transient var sqlContext: SQLContext = _
4231

43-
def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
44-
assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
45-
vector1 ~== vector2 absTol 1E-1
46-
}, "The vector value is not correct after normalization.")
32+
override def beforeAll(): Unit = {
33+
super.beforeAll()
34+
sqlContext = new SQLContext(sc)
4735
}
4836

4937
test("Polynomial expansion with default parameter") {
@@ -55,28 +43,27 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
5543
Vectors.sparse(3, Seq())
5644
)
5745

58-
val sqlContext = new SQLContext(sc)
59-
val dataFrame = sqlContext
60-
.createDataFrame(sc.parallelize(data, 2).map(Tuple1.apply)).toDF("features")
61-
62-
val polynomialMapper = new PolynomialMapper()
63-
.setInputCol("features")
64-
.setOutputCol("poly_features")
65-
6646
val twoDegreeExpansion: Array[Vector] = Array(
6747
Vectors.sparse(9, Array(0, 1, 3, 4, 6), Array(-2.0, 2.3, 4.0, -4.6, 5.29)),
6848
Vectors.dense(-2.0, 2.3, 4.0, -4.6, 5.29),
6949
Vectors.dense(Array.fill[Double](9)(0.0)),
7050
Vectors.dense(0.6, -1.1, -3.0, 0.36, -0.66, -1.8, 1.21, 3.3, 9.0),
7151
Vectors.sparse(9, Array.empty[Int], Array.empty[Double]))
7252

73-
val result = collectResult(polynomialMapper.transform(dataFrame))
74-
75-
println(result.mkString("\n"))
53+
val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
7654

77-
assertTypeOfVector(data, result)
78-
79-
assertValues(result, twoDegreeExpansion)
55+
val polynomialMapper = new PolynomialMapper()
56+
.setInputCol("features")
57+
.setOutputCol("polyFeatures")
58+
59+
polynomialMapper.transform(df).select("polyFeatures", "expected").collect().foreach {
60+
case Row(expanded: DenseVector, expected: DenseVector) =>
61+
assert(expanded ~== expected absTol 1e-1)
62+
case Row(expanded: SparseVector, expected: SparseVector) =>
63+
assert(expanded ~== expected absTol 1e-1)
64+
case _ =>
65+
throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
66+
}
8067
}
8168

8269
test("Polynomial expansion with setter") {
@@ -88,15 +75,6 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
8875
Vectors.sparse(3, Seq())
8976
)
9077

91-
val sqlContext = new SQLContext(sc)
92-
val dataFrame = sqlContext
93-
.createDataFrame(sc.parallelize(data, 2).map(Tuple1.apply)).toDF("features")
94-
95-
val polynomialMapper = new PolynomialMapper()
96-
.setInputCol("features")
97-
.setOutputCol("poly_features")
98-
.setDegree(3)
99-
10078
val threeDegreeExpansion: Array[Vector] = Array(
10179
Vectors.sparse(19, Array(0, 1, 3, 4, 6, 9, 10, 12, 15),
10280
Array(-2.0, 2.3, 4.0, -4.6, 5.29, -8.0, 9.2, -10.58, 12.167)),
@@ -106,11 +84,22 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
10684
1.98, 5.4, -1.33, -3.63, -9.9, -27.0),
10785
Vectors.sparse(19, Array.empty[Int], Array.empty[Double]))
10886

109-
val result = collectResult(polynomialMapper.transform(dataFrame))
11087

111-
assertTypeOfVector(data, result)
88+
val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
11289

113-
assertValues(result, threeDegreeExpansion)
90+
val polynomialMapper = new PolynomialMapper()
91+
.setInputCol("features")
92+
.setOutputCol("polyFeatures")
93+
.setDegree(3)
94+
95+
polynomialMapper.transform(df).select("polyFeatures", "expected").collect().foreach {
96+
case Row(expanded: DenseVector, expected: DenseVector) =>
97+
assert(expanded ~== expected absTol 1e-1)
98+
case Row(expanded: SparseVector, expected: SparseVector) =>
99+
assert(expanded ~== expected absTol 1e-1)
100+
case _ =>
101+
throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
102+
}
114103
}
115104
}
116105

0 commit comments

Comments (0)