@@ -819,7 +819,7 @@ private class BlockExpectationAggregator(
819819 @ transient private lazy val newMeansMat = DenseMatrix .zeros(numFeatures, k)
820820 @ transient private lazy val newCovsMat = DenseMatrix .zeros(covSize, k)
821821 @ transient private lazy val auxiliaryProbMat = DenseMatrix .zeros(blockSize, k)
822- @ transient private lazy val auxiliaryMat = DenseMatrix .zeros(blockSize, numFeatures)
822+ @ transient private lazy val auxiliaryPDFMat = DenseMatrix .zeros(blockSize, numFeatures)
823823 @ transient private lazy val auxiliaryCovVec = Vectors .zeros(covSize).toDense
824824
825825 @ transient private lazy val gaussians = {
@@ -852,20 +852,36 @@ private class BlockExpectationAggregator(
852852 val size = matrix.numRows
853853 require(weights.length == size)
854854
855+ val blas1 = BLAS .getBLAS(size)
856+ val blas2 = BLAS .getBLAS(k)
857+
855858 val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix .zeros(size, k)
856859 require(! probMat.isTransposed)
857860 java.util.Arrays .fill(probMat.values, EPSILON )
858861
859- val mat = if (blockSize == size) auxiliaryMat else DenseMatrix .zeros(size, numFeatures)
862+ val pdfMat = if (blockSize == size) auxiliaryPDFMat else DenseMatrix .zeros(size, numFeatures)
863+ val probSumVec = Vectors.dense(Array.fill(size)(k * EPSILON)).toDense // seeded with the k * EPSILON floor that probMat rows already hold, so each entry equals its probMat row sum and is never 0 when all pdfs underflow
860864 var j = 0
861- val blas1 = BLAS .getBLAS(size)
862865 while (j < k) {
863- val pdfVec = gaussians(j).pdf(matrix, mat)
864- blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0 , 1 ,
865- probMat.values, j * size, 1 )
866+ val pdfVec = gaussians(j).pdf(matrix, pdfMat)
867+ val w = bcWeights.value(j)
868+ blas1.daxpy(size, w, pdfVec.values, 1 , probSumVec.values, 1 )
869+ blas1.daxpy(size, w, pdfVec.values, 0 , 1 , probMat.values, j * size, 1 )
866870 j += 1
867871 }
868872
873+ var i = 0
874+ while (i < size) {
875+ val probSum = probSumVec(i)
876+ val weight = weights(i)
877+ blas2.dscal(k, weight / probSum, probMat.values, i, size)
878+ blas2.daxpy(k, 1.0 , probMat.values, i, size, newWeights, 0 , 1 )
879+ newLogLikelihood += math.log(probSum) * weight
880+ i += 1
881+ }
882+
883+ BLAS .gemm(1.0 , matrix.transpose, probMat, 1.0 , newMeansMat)
884+
869885 // compute the cov vector for each row vector
870886 val covVec = auxiliaryCovVec
871887 val covVecIter = matrix match {
@@ -886,18 +902,11 @@ private class BlockExpectationAggregator(
886902 }
887903 }
888904
889- val blas2 = BLAS .getBLAS(k)
890- covVecIter.zip(weights.iterator).zipWithIndex.foreach {
891- case ((covVec, weight), i) =>
892- val probSum = blas2.dasum(k, probMat.values, i, size)
893- blas2.dscal(k, weight / probSum, probMat.values, i, size)
894- blas2.daxpy(k, 1.0 , probMat.values, i, size, newWeights, 0 , 1 )
895- BLAS .nativeBLAS.dger(covSize, k, 1.0 , covVec.values, 0 , 1 ,
896- probMat.values, i, size, newCovsMat.values, 0 , covSize)
897- newLogLikelihood += math.log(probSum) * weight
905+ covVecIter.zipWithIndex.foreach { case (covVec, i) =>
906+ BLAS .nativeBLAS.dger(covSize, k, 1.0 , covVec.values, 0 , 1 ,
907+ probMat.values, i, size, newCovsMat.values, 0 , covSize)
898908 }
899909
900- BLAS .gemm(1.0 , matrix.transpose, probMat, 1.0 , newMeansMat)
901910 totalCnt += size
902911
903912 this
0 commit comments