Skip to content

Commit 31f8907

Browse files
committed
nit
1 parent 0eb0f07 commit 31f8907

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ private class BlockExpectationAggregator(
819819
@transient private lazy val newMeansMat = DenseMatrix.zeros(numFeatures, k)
820820
@transient private lazy val newCovsMat = DenseMatrix.zeros(covSize, k)
821821
@transient private lazy val auxiliaryProbMat = DenseMatrix.zeros(blockSize, k)
822-
@transient private lazy val auxiliaryMat = DenseMatrix.zeros(blockSize, numFeatures)
822+
@transient private lazy val auxiliaryPDFMat = DenseMatrix.zeros(blockSize, numFeatures)
823823
@transient private lazy val auxiliaryCovVec = Vectors.zeros(covSize).toDense
824824

825825
@transient private lazy val gaussians = {
@@ -852,20 +852,36 @@ private class BlockExpectationAggregator(
852852
val size = matrix.numRows
853853
require(weights.length == size)
854854

855+
val blas1 = BLAS.getBLAS(size)
856+
val blas2 = BLAS.getBLAS(k)
857+
855858
val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix.zeros(size, k)
856859
require(!probMat.isTransposed)
857860
java.util.Arrays.fill(probMat.values, EPSILON)
858861

859-
val mat = if (blockSize == size) auxiliaryMat else DenseMatrix.zeros(size, numFeatures)
862+
val pdfMat = if (blockSize == size) auxiliaryPDFMat else DenseMatrix.zeros(size, numFeatures)
863+
val probSumVec = Vectors.zeros(size).toDense
860864
var j = 0
861-
val blas1 = BLAS.getBLAS(size)
862865
while (j < k) {
863-
val pdfVec = gaussians(j).pdf(matrix, mat)
864-
blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0, 1,
865-
probMat.values, j * size, 1)
866+
val pdfVec = gaussians(j).pdf(matrix, pdfMat)
867+
val w = bcWeights.value(j)
868+
blas1.daxpy(size, w, pdfVec.values, 1, probSumVec.values, 1)
869+
blas1.daxpy(size, w, pdfVec.values, 0, 1, probMat.values, j * size, 1)
866870
j += 1
867871
}
868872

873+
var i = 0
874+
while (i < size) {
875+
val probSum = probSumVec(i)
876+
val weight = weights(i)
877+
blas2.dscal(k, weight / probSum, probMat.values, i, size)
878+
blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1)
879+
newLogLikelihood += math.log(probSum) * weight
880+
i += 1
881+
}
882+
883+
BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat)
884+
869885
// compute the cov vector for each row vector
870886
val covVec = auxiliaryCovVec
871887
val covVecIter = matrix match {
@@ -886,18 +902,11 @@ private class BlockExpectationAggregator(
886902
}
887903
}
888904

889-
val blas2 = BLAS.getBLAS(k)
890-
covVecIter.zip(weights.iterator).zipWithIndex.foreach {
891-
case ((covVec, weight), i) =>
892-
val probSum = blas2.dasum(k, probMat.values, i, size)
893-
blas2.dscal(k, weight / probSum, probMat.values, i, size)
894-
blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1)
895-
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1,
896-
probMat.values, i, size, newCovsMat.values, 0, covSize)
897-
newLogLikelihood += math.log(probSum) * weight
905+
covVecIter.zipWithIndex.foreach { case (covVec, i) =>
906+
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1,
907+
probMat.values, i, size, newCovsMat.values, 0, covSize)
898908
}
899909

900-
BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat)
901910
totalCnt += size
902911

903912
this

0 commit comments

Comments
 (0)