
Commit 4ebf34e

add test suite for polynomial expansion
1 parent 372227c commit 4ebf34e
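
This commit adds PolynomialMapperSuite, a ScalaTest suite for the ml.feature PolynomialMapper transformer. The tests check that expansion preserves the input vector type (dense in, dense out; sparse in, sparse out) and that the expanded values match the expected fixtures at the default degree and at degree 3.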

File tree

1 file changed: +103 −0 lines changed
@@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

private case class DataSet(features: Vector)

class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {

  @transient var data: Array[Vector] = _
  @transient var dataFrame: DataFrame = _
  @transient var polynomialMapper: PolynomialMapper = _
  @transient var oneDegreeExpansion: Array[Vector] = _
  @transient var threeDegreeExpansion: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))),
      Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))),
      Vectors.sparse(3, Seq())
    )
    // A degree-1 expansion leaves the input vectors unchanged.
    oneDegreeExpansion = data
    threeDegreeExpansion = Array(
      Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.184549876, -0.3383414, -0.922749378),
      Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))),
      Vectors.dense(0.897906166, 0.113419726, 0.42532397),
      Vectors.sparse(3, Seq())
    )

    val sqlContext = new SQLContext(sc)
    dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet))
    polynomialMapper = new PolynomialMapper()
      .setInputCol("features")
      .setOutputCol("poly_features")
  }

  def collectResult(result: DataFrame): Array[Vector] = {
    result.select("poly_features").collect().map {
      case Row(features: Vector) => features
    }
  }

  def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert((lhs, rhs).zipped.forall {
      case (v1: DenseVector, v2: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector) => true
      case _ => false
    }, "The vector type should be preserved after polynomial expansion.")
  }

  def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
      vector1 ~== vector2 absTol 1E-5
    }, "The vector values are not correct after polynomial expansion.")
  }

  test("Polynomial expansion with default parameter") {
    val result = collectResult(polynomialMapper.transform(dataFrame))

    assertTypeOfVector(data, result)

    assertValues(result, oneDegreeExpansion)
  }

  test("Polynomial expansion with setter") {
    polynomialMapper.setDegree(3)

    val result = collectResult(polynomialMapper.transform(dataFrame))

    assertTypeOfVector(data, result)

    assertValues(result, threeDegreeExpansion)
  }
}
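
For context, here is a minimal sketch of what a polynomial expansion computes, under the assumption that the transformer emits every monomial of the input features up to the requested degree. This is illustrative only: PolyExpandSketch and expand are hypothetical names, not the PolynomialMapper implementation under test, and the term ordering here is arbitrary.

// Illustrative sketch only — not the PolynomialMapper implementation under
// test; the names PolyExpandSketch and expand are hypothetical.
// For input (x, y) and degree 2 it produces the monomials x, y, x^2, x*y, y^2.
object PolyExpandSketch {
  def expand(values: Seq[Double], degree: Int): Seq[Double] = {
    // Index multisets of size d (combinations with repetition), so each
    // monomial of a given degree appears exactly once.
    def monomials(start: Int, d: Int): Seq[Seq[Int]] =
      if (d == 0) Seq(Seq.empty)
      else (start until values.length).flatMap(i => monomials(i, d - 1).map(i +: _))
    // Concatenate the monomials of every degree from 1 up to `degree`.
    (1 to degree).flatMap(d => monomials(0, d).map(_.map(values).product))
  }

  def main(args: Array[String]): Unit = {
    // (x, y) = (2.0, 3.0), degree 2 -> x, y, x^2, x*y, y^2
    println(expand(Seq(2.0, 3.0), 2)) // Vector(2.0, 3.0, 4.0, 6.0, 9.0)
  }
}

With degree 1 the sketch returns the input values unchanged, which is why the suite above can reuse data as oneDegreeExpansion for the default-parameter test.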
