Skip to content

Commit 35e007e

Browse files
committed
nit
nit opt_blas opt_blas
1 parent 473e5a2 commit 35e007e

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -81,41 +81,34 @@ class MultivariateGaussian @Since("2.0.0") (
8181
u - 0.5 * BLAS.dot(v, v)
8282
}
8383

84-
private[ml] def pdf(X: Matrix): Vector = {
85-
val m = X.numRows
86-
val n = X.numCols
87-
val mat = new DenseMatrix(m, n, Array.ofDim[Double](m * n))
84+
private[ml] def pdf(X: Matrix): DenseVector = {
85+
val mat = DenseMatrix.zeros(X.numRows, X.numCols)
8886
pdf(X, mat)
8987
}
9088

91-
private[ml] def pdf(X: Matrix, mat: DenseMatrix): Vector = {
89+
private[ml] def pdf(X: Matrix, mat: DenseMatrix): DenseVector = {
9290
require(!mat.isTransposed)
93-
val localU = u
94-
val localRootSigmaInvMat = rootSigmaInvMat
95-
val localRootSigmaInvMulMu = rootSigmaInvMulMu.toArray
9691

97-
BLAS.gemm(1.0, X, localRootSigmaInvMat.transpose, 0.0, mat)
98-
val arr = mat.values
92+
BLAS.gemm(1.0, X, rootSigmaInvMat.transpose, 0.0, mat)
9993
val m = mat.numRows
10094
val n = mat.numCols
10195

102-
val pdfArr = Array.ofDim[Double](m)
96+
val pdfVec = mat.multiply(rootSigmaInvMulMu)
97+
98+
val blas = BLAS.getBLAS(n)
99+
val squared1 = BLAS.dot(rootSigmaInvMulMu, rootSigmaInvMulMu)
100+
101+
val localU = u
103102
var i = 0
104103
while (i < m) {
105-
var squaredSum = 0.0
106-
var index = i
107-
var j = 0
108-
while (j < n) {
109-
val d = arr(index) - localRootSigmaInvMulMu(j)
110-
squaredSum += d * d
111-
index += m
112-
j += 1
113-
}
114-
pdfArr(i) = math.exp(localU - 0.5 * squaredSum)
104+
val squared2 = blas.ddot(n, mat.values, i, m, mat.values, i, m)
105+
val dot = pdfVec(i)
106+
val squaredSum = squared1 + squared2 - dot - dot
107+
pdfVec.values(i) = math.exp(localU - 0.5 * squaredSum)
115108
i += 1
116109
}
117110

118-
Vectors.dense(pdfArr)
111+
pdfVec
119112
}
120113

121114
/**

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,15 @@ class GaussianMixture @Since("2.0.0") (
376376

377377
/**
378378
* Set block size for stacking input data in matrices.
379+
* If blockSize == 1, then stacking will be skipped, and each vector is treated individually;
380+
* If blockSize &gt; 1, then vectors will be stacked to blocks, and high-level BLAS routines
381+
* will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
382+
* Recommended size is between 10 and 1000. An appropriate choice of the block size depends
383+
* on the sparsity and dim of input datasets, the underlying BLAS implementation (for example,
384+
* f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads).
385+
* Note that existing BLAS implementations are mainly optimized for dense matrices, if the
386+
* input dataset is sparse, stacking may bring no performance gain, the worse is possible
387+
* performance regression.
379388
* Default is 1.
380389
*
381390
* @group expertSetParam
@@ -809,6 +818,7 @@ private class BlockExpectationAggregator(
809818
private lazy val newWeights = Array.ofDim[Double](k)
810819
@transient private lazy val newMeansMat = DenseMatrix.zeros(numFeatures, k)
811820
@transient private lazy val newCovsMat = DenseMatrix.zeros(covSize, k)
821+
@transient private lazy val auxiliaryProbMat = DenseMatrix.zeros(blockSize, k)
812822
@transient private lazy val auxiliaryMat = DenseMatrix.zeros(blockSize, numFeatures)
813823
@transient private lazy val auxiliaryCovVec = Vectors.zeros(covSize).toDense
814824

@@ -840,15 +850,19 @@ private class BlockExpectationAggregator(
840850
val (matrix: Matrix, weights: Array[Double]) = weightedBlock
841851
require(matrix.isTransposed)
842852
val size = matrix.numRows
843-
val weightArr = bcWeights.value
853+
require(weights.length == size)
854+
855+
val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix.zeros(size, k)
856+
require(!probMat.isTransposed)
857+
java.util.Arrays.fill(probMat.values, EPSILON)
844858

845-
val probMat = DenseMatrix.zeros(size, k)
846859
val mat = if (blockSize == size) auxiliaryMat else DenseMatrix.zeros(size, numFeatures)
847860
var j = 0
861+
val blas1 = BLAS.getBLAS(size)
848862
while (j < k) {
849863
val pdfVec = oldGaussians(j).pdf(matrix, mat)
850-
var i = 0
851-
while (i < size) { probMat.update(i, j, EPSILON + weightArr(j) * pdfVec(i)); i += 1 }
864+
blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0, 1,
865+
probMat.values, j * size, 1)
852866
j += 1
853867
}
854868

@@ -858,7 +872,7 @@ private class BlockExpectationAggregator(
858872
case dm: DenseMatrix =>
859873
Iterator.tabulate(size) { i =>
860874
java.util.Arrays.fill(covVec.values, 0.0)
861-
// when input block is dense, directly using nativeBLAS to avoid array copy
875+
// when input block is dense, directly use nativeBLAS to avoid array copy
862876
BLAS.nativeBLAS.dspr("U", numFeatures, 1.0, dm.values, i * numFeatures, 1,
863877
covVec.values, 0)
864878
covVec
@@ -872,25 +886,15 @@ private class BlockExpectationAggregator(
872886
}
873887
}
874888

875-
val probArr = Array.ofDim[Double](k)
889+
val blas2 = BLAS.getBLAS(k)
876890
covVecIter.zip(weights.iterator).zipWithIndex.foreach {
877891
case ((covVec, weight), i) =>
878-
var j = 0
879-
var probSum = 0.0
880-
while (j < k) { probSum += probMat(i, j); j += 1 }
892+
val probSum = blas2.dasum(k, probMat.values, i, size)
893+
blas2.dscal(k, weight / probSum, probMat.values, i, size)
894+
blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1)
895+
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1,
896+
probMat.values, i, size, newCovsMat.values, 0, covSize)
881897
newLogLikelihood += math.log(probSum) * weight
882-
883-
j = 0
884-
while (j < k) {
885-
val w = probMat(i, j) / probSum * weight
886-
newWeights(j) += w
887-
probArr(j) = w
888-
probMat.update(i, j, w)
889-
j += 1
890-
}
891-
892-
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 1,
893-
probArr, 1, newCovsMat.values, covSize)
894898
}
895899

896900
BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat)

0 commit comments

Comments
 (0)