Skip to content

Commit ae9b94a

Browse files
committed
ColumnStatisticsAggregator doesn't merge mean correctly
1 parent 78157d4 commit ae9b94a

File tree

1 file changed

+9
-11
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed

1 file changed

+9
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,19 @@ private class ColumnStatisticsAggregator(private val n: Int)
136136

137137
var i = 0
138138
while (i < n) {
139-
// merge mean together
140-
if (other.currMean(i) != 0.0) {
139+
if (nnz(i) + other.nnz(i) != 0.0) {
140+
// merge mean together
141141
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
142142
(nnz(i) + other.nnz(i))
143-
}
144-
// merge m2n together
145-
if (nnz(i) + other.nnz(i) != 0.0) {
143+
// merge m2n together
146144
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
147145
(nnz(i) + other.nnz(i))
148-
}
149-
if (currMax(i) < other.currMax(i)) {
150-
currMax(i) = other.currMax(i)
151-
}
152-
if (currMin(i) > other.currMin(i)) {
153-
currMin(i) = other.currMin(i)
146+
if (currMax(i) < other.currMax(i)) {
147+
currMax(i) = other.currMax(i)
148+
}
149+
if (currMin(i) > other.currMin(i)) {
150+
currMin(i) = other.currMin(i)
151+
}
154152
}
155153
i += 1
156154
}

0 commit comments

Comments
 (0)