Commit c4651bb

remove row-wise APIs and refine code
1 parent 1338ea1 commit c4651bb

File tree (2 files changed: +52 −85 lines)

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala


mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 50 additions & 48 deletions
@@ -21,7 +21,6 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.rdd.RDD
-import breeze.linalg._
 
 /**
  * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -30,30 +29,6 @@ import breeze.linalg._
  */
 class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
 
-  /**
-   * Compute the mean of each `Vector` in the RDD.
-   */
-  def rowMeans(): RDD[Double] = {
-    self.map(x => x.toArray.sum / x.size)
-  }
-
-  /**
-   * Compute the norm-2 of each `Vector` in the RDD.
-   */
-  def rowNorm2(): RDD[Double] = {
-    self.map(x => math.sqrt(x.toArray.map(x => x*x).sum))
-  }
-
-  /**
-   * Compute the standard deviation of each `Vector` in the RDD.
-   */
-  def rowSDs(): RDD[Double] = {
-    val means = self.rowMeans()
-    self.zip(means)
-      .map{ case(x, m) => x.toBreeze - m }
-      .map{ x => math.sqrt(x.toArray.map(x => x*x).sum / x.size) }
-  }
-
   /**
    * Compute the mean of each column in the RDD.
    */
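Note: the removed row-wise helpers need no dedicated API, since each is a one-line transformation on the RDD itself. A minimal sketch of the equivalents, given any `rdd: RDD[Vector]` and the imports already present in this file (the val names rowMeans/rowNorm2/rowSDs below are illustrative locals, not MLlib API):

    // Row mean: average of each vector's elements.
    val rowMeans: RDD[Double] = rdd.map(v => v.toArray.sum / v.size)
    // Row norm-2: Euclidean length of each vector.
    val rowNorm2: RDD[Double] = rdd.map(v => math.sqrt(v.toArray.map(x => x * x).sum))
    // Row population SD, computed in one pass per vector instead of the
    // removed zip-with-rowMeans approach.
    val rowSDs: RDD[Double] = rdd.map { v =>
      val m = v.toArray.sum / v.size
      math.sqrt(v.toArray.map(x => (x - m) * (x - m)).sum / v.size)
    }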
@@ -137,11 +112,6 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
    */
   def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _))
 
-  /**
-   * Filter the vectors whose standard deviation is not zero.
-   */
-  def rowShrink(): RDD[Vector] = self.zip(self.rowSDs()).filter(_._2 != 0.0).map(_._1)
-
   /**
    * Filter each column of the RDD whose standard deviation is not zero.
    */
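The removed rowShrink (drop rows whose standard deviation is zero) is likewise expressible inline. A sketch, using the fact that a vector's SD is zero exactly when every element equals the row mean (illustrative, not MLlib API):

    // Keep only non-constant rows; same effect as the removed rowShrink(),
    // without computing any square roots.
    val rowShrunk: RDD[Vector] = rdd.filter { v =>
      val m = v.toArray.sum / v.size
      v.toArray.exists(x => x != m)
    }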
@@ -163,34 +133,66 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
     }
   }
 
-  def parallelMeanAndVar(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
-    val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0, BV.zeros[Double](size), BV.fill(size){Double.MinValue}, BV.fill(size){Double.MaxValue}))(
+  /**
+   * Compute full column-wise statistics for the RDD, including
+   * {{{
+   *   Mean: Vector,
+   *   Variance: Vector,
+   *   Count: Double,
+   *   Non-zero count: Vector,
+   *   Maximum elements: Vector,
+   *   Minimum elements: Vector.
+   * }}},
+   * with the size of Vector as input parameter.
+   */
+  def statistics(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
+    val results = self.map(_.toBreeze).aggregate((
+      BV.zeros[Double](size),
+      BV.zeros[Double](size),
+      0.0,
+      BV.zeros[Double](size),
+      BV.fill(size){Double.MinValue},
+      BV.fill(size){Double.MaxValue}))(
       seqOp = (c, v) => (c, v) match {
-        case ((prevMean, prevM2n, cnt, nnz, maxVec, minVec), currData) =>
+        case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
           val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
-          val nonZeroCnt = Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
+          val nonZeroCnt = Vectors
+            .sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
           currData.activeIterator.foreach { case (id, value) =>
             if (maxVec(id) < value) maxVec(id) = value
             if (minVec(id) > value) minVec(id) = value
           }
-          (currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0, nnz + nonZeroCnt, maxVec, minVec)
+          (currMean,
+            prevM2n + ((currData - prevMean) :* (currData - currMean)),
+            cnt + 1.0,
+            nnzVec + nonZeroCnt,
+            maxVec,
+            minVec)
       },
       combOp = (lhs, rhs) => (lhs, rhs) match {
-        case ((lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin), (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
-          val totalCnt = lhsCnt + rhsCnt
-          val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
-          val deltaMean = rhsMean - lhsMean
-          val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
-          rhsMax.activeIterator.foreach { case (id, value) =>
-            if (lhsMax(id) < value) lhsMax(id) = value
-          }
-          rhsMin.activeIterator.foreach { case (id, value) =>
-            if (lhsMin(id) > value) lhsMin(id) = value
-          }
-          (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
+        case (
+          (lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
+          (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
+          val totalCnt = lhsCnt + rhsCnt
+          val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
+          val deltaMean = rhsMean - lhsMean
+          val totalM2n =
+            lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
+          rhsMax.activeIterator.foreach { case (id, value) =>
+            if (lhsMax(id) < value) lhsMax(id) = value
+          }
+          rhsMin.activeIterator.foreach { case (id, value) =>
+            if (lhsMin(id) > value) lhsMin(id) = value
+          }
+          (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
       }
     )
 
-    (Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3), statistics._3, Vectors.fromBreeze(statistics._4), Vectors.fromBreeze(statistics._5), Vectors.fromBreeze(statistics._6))
+    (Vectors.fromBreeze(results._1),
+      Vectors.fromBreeze(results._2 :/ results._3),
+      results._3,
+      Vectors.fromBreeze(results._4),
+      Vectors.fromBreeze(results._5),
+      Vectors.fromBreeze(results._6))
   }
 }
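The renamed statistics(size) makes a single pass with aggregate: seqOp folds one vector into the running (mean, M2n, count, nnz, max, min) state, and combOp merges two partial states with the standard pairwise update totalM2n = lhsM2n + rhsM2n + deltaMean^2 * (lhsCnt * rhsCnt) / totalCnt (Chan et al.), so the sum of squared deviations stays exact however partitions are combined. A minimal usage sketch, assuming a live SparkContext `sc` and that the implicit conversion enriching RDD[Vector] with these functions (the one the test suite relies on) is in scope:

    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(4.0, 5.0, 6.0),
      Vectors.dense(7.0, 8.0, 9.0)), 2)

    // `3` is the vector size (column count); one job computes all six results.
    val (mean, variance, count, nnz, max, min) = data.statistics(3)
    // mean     = [4.0, 5.0, 6.0]
    // variance = [6.0, 6.0, 6.0]
    // count    = 3.0
    // nnz = [3.0, 3.0, 3.0], max = [7.0, 8.0, 9.0], min = [1.0, 2.0, 3.0]

Note that the returned variance is the population variance: the tuple's second element is results._2 :/ results._3, i.e. M2n divided by the count, not by count - 1.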

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 2 additions & 37 deletions
@@ -31,10 +31,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     Vectors.dense(7.0, 8.0, 9.0)
   )
 
-  val rowMeans = Array(2.0, 5.0, 8.0)
-  val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
-  val rowSDs = Array(math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0))
-
   val colMeans = Array(4.0, 5.0, 6.0)
   val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0))
   val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0))
@@ -49,35 +45,12 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     Vectors.dense(7.0, 8.0, 0.0)
   )
 
-  val rowShrinkData = Array(
-    Vectors.dense(1.0, 2.0, 0.0),
-    Vectors.dense(7.0, 8.0, 0.0)
-  )
-
   val colShrinkData = Array(
     Vectors.dense(1.0, 2.0),
     Vectors.dense(0.0, 0.0),
     Vectors.dense(7.0, 8.0)
   )
 
-  test("rowMeans") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)),
-      "Row means do not match.")
-  }
-
-  test("rowNorm2") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)),
-      "Row norm2s do not match.")
-  }
-
-  test("rowSDs") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)),
-      "Row SDs do not match.")
-  }
-
   test("colMeans") {
     val data = sc.parallelize(localData, 2)
     assert(equivVector(data.colMeans(), Vectors.dense(colMeans)),
@@ -114,14 +87,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     )
   }
 
-  test("rowShrink") {
-    val data = sc.parallelize(shrinkingData, 2)
-    val res = data.rowShrink().collect()
-    rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
-      assert(equivVector(lhs, rhs), "Row shrink error.")
-    }
-  }
-
   test("columnShrink") {
     val data = sc.parallelize(shrinkingData, 2)
     val res = data.colShrink().collect()
@@ -130,9 +95,9 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("meanAndVar") {
+  test("full-statistics") {
     val data = sc.parallelize(localData, 2)
-    val (mean, sd, cnt, nnz, max, min) = data.parallelMeanAndVar(3)
+    val (mean, sd, cnt, nnz, max, min) = data.statistics(3)
     assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
     assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.")
     assert(cnt === 3, "Column cnt do not match.")
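For reference, the suite's expected constants can be checked by hand on localData's first column {1.0, 4.0, 7.0}; a plain-Scala sketch, no Spark required (colVar is the array of population variances referenced by the full-statistics test):

    val col  = Seq(1.0, 4.0, 7.0)
    val mean = col.sum / col.size                        // 4.0  == colMeans(0)
    val m2n  = col.map(x => (x - mean) * (x - mean)).sum // 18.0
    val varp = m2n / col.size                            // 6.0  == colVar(0)
    val sd   = math.sqrt(varp)                           // sqrt(6.0)  == colSDs(0)
    val nrm2 = math.sqrt(col.map(x => x * x).sum)        // sqrt(66.0) == colNorm2(0)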
