
Commit 6fa236f

fix vector slice error
1 parent daff601 commit 6fa236f

2 files changed: +15, -25 lines

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialMapper.scala

Lines changed: 8 additions & 10 deletions

@@ -86,25 +86,23 @@ object PolynomialMapper {
   private def expandVector(lhs: Vector, rhs: Vector, nDim: Int, currDegree: Int): Vector = {
     (lhs, rhs) match {
       case (l: DenseVector, r: DenseVector) =>
-        var rightVectorView = rhs.toArray
+        val rLen = rhs.size
         val allExpansions = l.toArray.zipWithIndex.flatMap { case (lVal, lIdx) =>
-          val currExpansions = rightVectorView.map(rVal => lVal * rVal)
-          val numToRemove = numMonomials(currDegree - 1, nDim - lIdx)
-          rightVectorView = rightVectorView.drop(numToRemove)
-          currExpansions
+          val numToKeep = numMonomials(currDegree - 1, nDim - lIdx)
+          r.toArray.slice(rLen - numToKeep, rLen).map(rVal => lVal * rVal)
         }
         Vectors.dense(allExpansions)

       case (SparseVector(lLen, lIdx, lVal), SparseVector(rLen, rIdx, rVal)) =>
         val len = numMonomials(currDegree, nDim)
-        var numToRemoveCum = 0
+        var numToKeepCum = 0
         val allExpansions = lVal.zip(lIdx).flatMap { case (lv, li) =>
-          val numToRemove = numMonomials(currDegree - 1, nDim - li)
+          val numToKeep = numMonomials(currDegree - 1, nDim - li)
           val currExpansions = rVal.zip(rIdx).map { case (rv, ri) =>
-            val realIdx = ri - (rLen - numToRemove)
-            (if (realIdx >= 0) lv * rv else 0.0, numToRemoveCum + realIdx)
+            val realIdx = ri - (rLen - numToKeep)
+            (if (realIdx >= 0) lv * rv else 0.0, numToKeepCum + realIdx)
           }
-          numToRemoveCum += numToRemove
+          numToKeepCum += numToKeep
           currExpansions
         }.filter(_._1 != 0.0)
         Vectors.sparse(len, allExpansions.map(_._2), allExpansions.map(_._1))
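
The fix replaces the mutable rightVectorView/drop bookkeeping with a read-only tail slice: for the i-th coordinate of the left vector, only the last numMonomials(currDegree - 1, nDim - i) entries of the previous expansion are multiplied in, so every monomial of the new degree is produced exactly once and in a fixed order. The sparse branch applies the same rule by re-indexing (realIdx) and dropping products that would fall before the kept tail. The following is a minimal, self-contained sketch of that slicing idea on plain Scala arrays, not the MLlib code itself; ExpandSketch, expandOnce and monomialCount are illustrative names, and monomialCount assumes numMonomials(degree, dims) counts the monomials of exactly that degree over dims variables, i.e. C(degree + dims - 1, degree).

object ExpandSketch {
  // Stand-in for numMonomials(degree, dims): C(degree + dims - 1, degree), the number of
  // monomials of exactly `degree` over `dims` variables.
  def monomialCount(degree: Int, dims: Int): Int = {
    var result = 1L
    for (k <- 1 to degree) {
      result = result * (dims - 1 + k) / k
    }
    result.toInt
  }

  // `prev` holds the monomials of degree (currDegree - 1), laid out so that those using only
  // variables with index >= i occupy the tail of length monomialCount(currDegree - 1, nDim - i).
  // Pairing x(i) with just that tail yields each degree-`currDegree` monomial exactly once,
  // which is the invariant the slice in the dense branch relies on.
  def expandOnce(x: Array[Double], prev: Array[Double], nDim: Int, currDegree: Int): Array[Double] = {
    val prevLen = prev.length
    x.zipWithIndex.flatMap { case (xi, i) =>
      val numToKeep = monomialCount(currDegree - 1, nDim - i)
      prev.slice(prevLen - numToKeep, prevLen).map(xi * _)
    }
  }

  def main(args: Array[String]): Unit = {
    val x = Array(0.6, -1.1, -3.0)
    val degree2 = expandOnce(x, x, nDim = 3, currDegree = 2)
    // ~ 0.36, -0.66, -1.8, 1.21, 3.3, 9.0 (up to floating-point rounding)
    println(degree2.mkString(", "))
    val degree3 = expandOnce(x, degree2, nDim = 3, currDegree = 3)
    // ~ 0.216, -0.396, -1.08, 0.726, 1.98, 5.4, -1.331, -3.63, -9.9, -27.0
    println(degree3.mkString(", "))
  }
}

The two printed arrays correspond to the degree-2 and degree-3 segments of the dense expectation in the updated test suite below (where a few entries are rounded).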

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialMapperSuite.scala

Lines changed: 7 additions & 15 deletions

@@ -43,11 +43,15 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
       Vectors.dense(0.6, -1.1, -3.0),
       Vectors.sparse(3, Seq())
     )
+
     oneDegreeExpansion = data
+
     threeDegreeExpansion = Array(
-      Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))),
+      Vectors.sparse(
+        19,Array(0,1,3,4,6,9,10,12,15),Array(-2.0,2.3,4.0,-4.6,5.29,-8.0,9.2,-10.58,12.17)),
       Vectors.dense(Array.fill[Double](19)(0.0)),
-      Vectors.dense(0.184549876, -0.3383414, -0.922749378),
+      Vectors.dense(0.6,-1.1,-3.0,0.36,-0.66,-1.8,1.21,3.3,9.0,0.216,-0.396,-1.08,0.73,1.98,5.4,
+        -1.33,-3.63,-9.9,-27.0),
       Vectors.sparse(19, Seq())
     )

@@ -74,21 +78,10 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {

   def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
     assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
-      vector1 ~== vector2 absTol 1E-5
+      vector1 ~== vector2 absTol 1E-1
     }, "The vector value is not correct after normalization.")
   }

-  test("fake") {
-    polynomialMapper.setDegree(3)
-    println(polynomialMapper.getDegree)
-    val result = collectResult(polynomialMapper.transform(dataFrame))
-    for(r <- result) {
-      println(r)
-    }
-
-  }
-  /*
-
   test("Polynomial expansion with default parameter") {
     val result = collectResult(polynomialMapper.transform(dataFrame))

@@ -106,6 +99,5 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {

     assertValues(result, threeDegreeExpansion)
   }
-  */
 }
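
For reference, the new dense expectation in the first hunk is just the degree-1, degree-2 and degree-3 monomials of (0.6, -1.1, -3.0) concatenated in the order expandVector emits them. A few entries (0.73, -1.33, 12.17) are rounded to two decimals, which is consistent with loosening the tolerance in assertValues from 1E-5 to 1E-1. A quick hand-check of those rounded entries, runnable in the Scala REPL (not part of the suite):

println(0.6 * -1.1 * -1.1)   // ~ 0.726  -> listed as 0.73 in the dense expectation
println(-1.1 * -1.1 * -1.1)  // ~ -1.331 -> listed as -1.33
println(2.3 * 2.3 * 2.3)     // ~ 12.167 -> listed as 12.17 in the sparse expectation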
