@@ -376,6 +376,15 @@ class GaussianMixture @Since("2.0.0") (
376376
377377 /**
378378 * Set block size for stacking input data in matrices.
379+ * If blockSize == 1, then stacking will be skipped, and each vector is treated individually;
380+ * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines
381+ * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
382+ * Recommended size is between 10 and 1000. An appropriate choice of the block size depends
383+ * on the sparsity and dimension of the input dataset, the underlying BLAS implementation (for
384+ * example, f2jBLAS, OpenBLAS, Intel MKL) and its configuration (for example, number of threads).
385+ * Note that existing BLAS implementations are mainly optimized for dense matrices; if the
386+ * input dataset is sparse, stacking may bring no performance gain and may even cause a
387+ * performance regression.
379388 * Default is 1.
380389 *
381390 * @group expertSetParam
@@ -809,6 +818,7 @@ private class BlockExpectationAggregator(
809818 private lazy val newWeights = Array .ofDim[Double ](k)
810819 @ transient private lazy val newMeansMat = DenseMatrix .zeros(numFeatures, k)
811820 @ transient private lazy val newCovsMat = DenseMatrix .zeros(covSize, k)
821+ @ transient private lazy val auxiliaryProbMat = DenseMatrix .zeros(blockSize, k)
812822 @ transient private lazy val auxiliaryMat = DenseMatrix .zeros(blockSize, numFeatures)
813823 @ transient private lazy val auxiliaryCovVec = Vectors .zeros(covSize).toDense
814824
@@ -840,15 +850,19 @@ private class BlockExpectationAggregator(
840850 val (matrix : Matrix , weights : Array [Double ]) = weightedBlock
841851 require(matrix.isTransposed)
842852 val size = matrix.numRows
843- val weightArr = bcWeights.value
853+ require(weights.length == size)
854+
855+ val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix .zeros(size, k)
856+ require(! probMat.isTransposed)
857+ java.util.Arrays .fill(probMat.values, EPSILON )
844858
845- val probMat = DenseMatrix .zeros(size, k)
846859 val mat = if (blockSize == size) auxiliaryMat else DenseMatrix .zeros(size, numFeatures)
847860 var j = 0
861+ val blas1 = BLAS .getBLAS(size)
848862 while (j < k) {
849863 val pdfVec = oldGaussians(j).pdf(matrix, mat)
850- var i = 0
851- while (i < size) { probMat.update(i , j, EPSILON + weightArr(j) * pdfVec(i)); i += 1 }
864+ blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0 , 1 ,
865+ probMat.values , j * size, 1 )
852866 j += 1
853867 }
854868
@@ -858,7 +872,7 @@ private class BlockExpectationAggregator(
858872 case dm : DenseMatrix =>
859873 Iterator .tabulate(size) { i =>
860874 java.util.Arrays .fill(covVec.values, 0.0 )
861- // when input block is dense, directly using nativeBLAS to avoid array copy
875+ // when input block is dense, directly use nativeBLAS to avoid array copy
862876 BLAS .nativeBLAS.dspr(" U" , numFeatures, 1.0 , dm.values, i * numFeatures, 1 ,
863877 covVec.values, 0 )
864878 covVec
@@ -872,25 +886,15 @@ private class BlockExpectationAggregator(
872886 }
873887 }
874888
875- val probArr = Array .ofDim[ Double ] (k)
889+ val blas2 = BLAS .getBLAS (k)
876890 covVecIter.zip(weights.iterator).zipWithIndex.foreach {
877891 case ((covVec, weight), i) =>
878- var j = 0
879- var probSum = 0.0
880- while (j < k) { probSum += probMat(i, j); j += 1 }
892+ val probSum = blas2.dasum(k, probMat.values, i, size)
893+ blas2.dscal(k, weight / probSum, probMat.values, i, size)
894+ blas2.daxpy(k, 1.0 , probMat.values, i, size, newWeights, 0 , 1 )
895+ BLAS .nativeBLAS.dger(covSize, k, 1.0 , covVec.values, 0 , 1 ,
896+ probMat.values, i, size, newCovsMat.values, 0 , covSize)
881897 newLogLikelihood += math.log(probSum) * weight
882-
883- j = 0
884- while (j < k) {
885- val w = probMat(i, j) / probSum * weight
886- newWeights(j) += w
887- probArr(j) = w
888- probMat.update(i, j, w)
889- j += 1
890- }
891-
892- BLAS .nativeBLAS.dger(covSize, k, 1.0 , covVec.values, 1 ,
893- probArr, 1 , newCovsMat.values, covSize)
894898 }
895899
896900 BLAS .gemm(1.0 , matrix.transpose, probMat, 1.0 , newMeansMat)
0 commit comments