
Commit d50abdf

yanboliang authored and cmonkey committed
[SPARK-17847][ML] Reduce shuffled data size of GaussianMixture & copy the implementation from mllib to ml
## What changes were proposed in this pull request?

Copy the `GaussianMixture` implementation from mllib to ml, so that we can add new features to it. I left the mllib `GaussianMixture` untouched, rather than wrapping the ml implementation as some other algorithms do, for the following reasons:
- mllib `GaussianMixture` allows k == 1, but ml does not.
- mllib `GaussianMixture` supports setting an initial model, which ml does not support yet. (We will definitely add this feature to ml in the future.)

We could work around these issues to make mllib a wrapper calling into ml, but I'd prefer to leave mllib untouched, which keeps ml clean.

Meanwhile, this PR brings a big performance improvement for `GaussianMixture`. Since the covariance matrix of a multivariate Gaussian distribution is symmetric, we can store only the upper triangular part of the matrix, which greatly reduces the shuffled data size. In my test, this change reduced the shuffled data size by about 50% and accelerated job execution.

Before this PR:
![image](https://cloud.githubusercontent.com/assets/1962026/19641622/4bb017ac-9996-11e6-8ece-83db184b620a.png)

After this PR:
![image](https://cloud.githubusercontent.com/assets/1962026/19641635/629c21fe-9996-11e6-91e9-83ab74ae0126.png)

## How was this patch tested?

Existing tests and added new tests.

Author: Yanbo Liang <[email protected]>

Closes apache#15413 from yanboliang/spark-17847.
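Where the roughly 50% reduction comes from: a full d × d covariance matrix carries d² doubles per Gaussian per aggregation, while the packed upper triangle carries d(d + 1)/2. A standalone sketch (not part of the patch) of the ratio:

```scala
// Entries shuffled per Gaussian: full matrix vs. packed upper triangle.
// The ratio tends to 1/2 as d grows, matching the ~50% reduction above.
object PackedSizeCheck extends App {
  for (d <- Seq(10, 100, 1000)) {
    val full = d.toLong * d                // d^2 entries
    val packed = d.toLong * (d + 1) / 2    // d(d+1)/2 entries
    println(f"d=$d%4d full=$full%8d packed=$packed%8d ratio=${packed.toDouble / full}%.3f")
  }
}
```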
1 parent 5b81d70 commit d50abdf

File tree

3 files changed: +469 −36 lines changed


mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 315 additions & 16 deletions
```diff
@@ -21,14 +21,14 @@ import breeze.linalg.{DenseVector => BDV}
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.impl.Utils.EPSILON
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.stat.distribution.MultivariateGaussian
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
 import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
   Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
```
```diff
@@ -45,6 +45,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
 
   /**
    * Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -57,6 +58,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
 
   /**
    * Validates and transforms the input schema.
+   *
    * @param schema input schema
    * @return output schema
    */
@@ -238,6 +240,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
 
   /**
    * Compute the probability (partial assignment) for each cluster for the given data point.
+   *
    * @param features Data point
    * @param dists Gaussians for model
    * @param weights Weights for each Gaussian
```
```diff
@@ -323,31 +326,98 @@ class GaussianMixture @Since("2.0.0") (
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Number of samples per cluster to use when initializing Gaussians.
+   */
+  private val numSamples = 5
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
     transformSchema(dataset.schema, logging = true)
-    val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
-      case Row(point: Vector) => OldVectors.fromML(point)
-    }
 
-    val instr = Instrumentation.create(this, rdd)
+    val sc = dataset.sparkSession.sparkContext
+    val numClusters = $(k)
+
+    val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map {
+      case Row(features: Vector) => features
+    }.cache()
+
+    // Extract the number of features.
+    val numFeatures = instances.first().size
+
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
+    instr.logNumFeatures(numFeatures)
+
+    val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(
+      numClusters, numFeatures)
+
+    // TODO: SPARK-15785 Support users supplied initial GMM.
+    val (weights, gaussians) = initRandom(instances, numClusters, numFeatures)
+
+    var logLikelihood = Double.MinValue
+    var logLikelihoodPrev = 0.0
+
+    var iter = 0
+    while (iter < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) {
+
+      val bcWeights = instances.sparkContext.broadcast(weights)
+      val bcGaussians = instances.sparkContext.broadcast(gaussians)
+
+      // aggregate the cluster contribution for all sample points
+      val sums = instances.treeAggregate(
+        new ExpectationAggregator(numFeatures, bcWeights, bcGaussians))(
+        seqOp = (c, v) => (c, v) match {
+          case (aggregator, instance) => aggregator.add(instance)
+        },
+        combOp = (c1, c2) => (c1, c2) match {
+          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+        })
+
+      bcWeights.destroy(blocking = false)
+      bcGaussians.destroy(blocking = false)
+
+      /*
+         Create new distributions based on the partial assignments
+         (often referred to as the "M" step in literature)
+       */
+      val sumWeights = sums.weights.sum
+
+      if (shouldDistributeGaussians) {
+        val numPartitions = math.min(numClusters, 1024)
+        val tuples = Seq.tabulate(numClusters) { i =>
+          (sums.means(i), sums.covs(i), sums.weights(i))
+        }
+        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, cov, weight) =>
+          GaussianMixture.updateWeightsAndGaussians(mean, cov, weight, sumWeights)
+        }.collect().unzip
+        Array.copy(ws.toArray, 0, weights, 0, ws.length)
+        Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
+      } else {
+        var i = 0
+        while (i < numClusters) {
+          val (weight, gaussian) = GaussianMixture.updateWeightsAndGaussians(
+            sums.means(i), sums.covs(i), sums.weights(i), sumWeights)
+          weights(i) = weight
+          gaussians(i) = gaussian
+          i += 1
+        }
+      }
+
+      logLikelihoodPrev = logLikelihood   // current becomes previous
+      logLikelihood = sums.logLikelihood  // this is the freshly computed log-likelihood
+      iter += 1
+    }
 
-    val algo = new MLlibGM()
-      .setK($(k))
-      .setMaxIterations($(maxIter))
-      .setSeed($(seed))
-      .setConvergenceTol($(tol))
-    val parentModel = algo.run(rdd)
-    val gaussians = parentModel.gaussians.map { case g =>
-      new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+    val gaussianDists = gaussians.map { case (mean, covVec) =>
+      val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+      new MultivariateGaussian(mean, cov)
     }
-    val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
-      .setParent(this)
+
+    val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
     val summary = new GaussianMixtureSummary(model.transform(dataset),
       $(predictionCol), $(probabilityCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
-    instr.logNumFeatures(model.gaussians.head.mean.size)
     instr.logSuccess(model)
     model
   }
```
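For orientation, here is a minimal, hypothetical usage sketch of the estimator whose `fit` method is rewritten above (the toy data, app name, and column name are invented; `"features"` is the default features column):

```scala
import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Toy two-cluster data; fit drives the EM loop shown in the diff above.
val spark = SparkSession.builder.appName("gmm-sketch").getOrCreate()
val data = Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.2, 0.1),
  Vectors.dense(9.0, 9.1), Vectors.dense(9.2, 8.9)
).map(Tuple1.apply)
val df = spark.createDataFrame(data).toDF("features")

val model = new GaussianMixture()
  .setK(2).setMaxIter(100).setTol(0.01).setSeed(1L)
  .fit(df)
model.gaussians.foreach(g => println(s"mean=${g.mean}\ncov=${g.cov}"))
```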
```diff
@@ -356,13 +426,242 @@ class GaussianMixture @Since("2.0.0") (
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
+
+  /**
+   * Initialize weights and corresponding gaussian distributions at random.
+   *
+   * We start with uniform weights, a random mean from the data, and diagonal covariance matrices
+   * using component variances derived from the samples.
+   *
+   * @param instances The training instances.
+   * @param numClusters The number of clusters.
+   * @param numFeatures The number of features of training instance.
+   * @return The initialized weights and corresponding gaussian distributions. Note the
+   *         covariance matrix of multivariate gaussian distribution is symmetric and
+   *         we only save the upper triangular part as a dense vector (column major).
+   */
+  private def initRandom(
+      instances: RDD[Vector],
+      numClusters: Int,
+      numFeatures: Int): (Array[Double], Array[(DenseVector, DenseVector)]) = {
+    val samples = instances.takeSample(withReplacement = true, numClusters * numSamples, $(seed))
+    val weights: Array[Double] = Array.fill(numClusters)(1.0 / numClusters)
+    val gaussians: Array[(DenseVector, DenseVector)] = Array.tabulate(numClusters) { i =>
+      val slice = samples.view(i * numSamples, (i + 1) * numSamples)
+      val mean = {
+        val v = new DenseVector(new Array[Double](numFeatures))
+        var i = 0
+        while (i < numSamples) {
+          BLAS.axpy(1.0, slice(i), v)
+          i += 1
+        }
+        BLAS.scal(1.0 / numSamples, v)
+        v
+      }
+      /*
+         Construct matrix where diagonal entries are element-wise
+         variance of input vectors (computes biased variance).
+         Since the covariance matrix of multivariate gaussian distribution is symmetric,
+         only the upper triangular part of the matrix (column major) will be saved as
+         a dense vector in order to reduce the shuffled data size.
+       */
+      val cov = {
+        val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
+        slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0)
+        val diagVec = Vectors.fromBreeze(ss)
+        BLAS.scal(1.0 / numSamples, diagVec)
+        val covVec = new DenseVector(Array.fill[Double](
+          numFeatures * (numFeatures + 1) / 2)(0.0))
+        diagVec.toArray.zipWithIndex.foreach { case (v: Double, i: Int) =>
+          covVec.values(i + i * (i + 1) / 2) = v
+        }
+        covVec
+      }
+      (mean, cov)
+    }
+    (weights, gaussians)
+  }
```
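The only non-obvious step above is the index `i + i * (i + 1) / 2`: in column-major packed upper-triangular storage, entry (r, c) with r <= c lives at offset r + c(c + 1)/2, so that expression addresses the diagonal slots. A standalone check (hypothetical helper, not part of the patch):

```scala
// For a 3 x 3 matrix, packed column-major upper-triangular order is:
// (0,0), (0,1), (1,1), (0,2), (1,2), (2,2)
def packedIndex(r: Int, c: Int): Int = r + c * (c + 1) / 2  // requires r <= c
assert((0 to 2).map(i => packedIndex(i, i)) == Seq(0, 2, 5))  // diagonal slots
```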
```diff
 }
 
 @Since("2.0.0")
 object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
 
   @Since("2.0.0")
   override def load(path: String): GaussianMixture = super.load(path)
+
+  /**
+   * Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when
+   * numFeatures > 25 except for when numClusters is very small.
+   *
+   * @param numClusters Number of clusters
+   * @param numFeatures Number of features
+   */
+  private[clustering] def shouldDistributeGaussians(
+      numClusters: Int,
+      numFeatures: Int): Boolean = {
+    ((numClusters - 1.0) / numClusters) * numFeatures > 25.0
+  }
```
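The heuristic scales the feature count by (k − 1)/k, so with k = 2 it distributes only when numFeatures > 50, while for large k the threshold approaches numFeatures > 25. A few worked values, as a standalone sketch (the local `shouldDistribute` mirrors the method above):

```scala
// Worked examples of the distribution heuristic above.
def shouldDistribute(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25.0
assert(!shouldDistribute(k = 2, d = 40))  // 0.5 * 40 = 20 <= 25: update locally
assert(shouldDistribute(k = 10, d = 40))  // 0.9 * 40 = 36 >  25: distribute
assert(shouldDistribute(k = 2, d = 60))   // 0.5 * 60 = 30 >  25: distribute
```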
```diff
+
+  /**
+   * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *
+   * @param n The order of the n by n matrix.
+   * @param triangularValues The upper triangular part of the matrix packed in an array
+   *                         (column major).
+   * @return A dense matrix which represents the symmetric matrix in column major.
+   */
+  private[clustering] def unpackUpperTriangularMatrix(
+      n: Int,
+      triangularValues: Array[Double]): DenseMatrix = {
+    val symmetricValues = new Array[Double](n * n)
+    var r = 0
+    var i = 0
+    while (i < n) {
+      var j = 0
+      while (j <= i) {
+        symmetricValues(i * n + j) = triangularValues(r)
+        symmetricValues(j * n + i) = triangularValues(r)
+        r += 1
+        j += 1
+      }
+      i += 1
+    }
+    new DenseMatrix(n, n, symmetricValues)
+  }
```
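A small standalone round-trip under the same column-major packing convention (hypothetical snippet, not part of the patch): pack the upper triangle of a symmetric 3 × 3 matrix, then unpack it with the same double loop as above.

```scala
// Pack the upper triangle column-major, then unpack it with the same loop
// structure as unpackUpperTriangularMatrix and check we recover the matrix.
val full = Array(        // column-major symmetric 3x3 matrix
  4.0, 1.0, 2.0,
  1.0, 5.0, 3.0,
  2.0, 3.0, 6.0)
val n = 3
val packed = for { j <- 0 until n; i <- 0 to j } yield full(j * n + i)
// packed == Vector(4.0, 1.0, 5.0, 2.0, 3.0, 6.0): 6 values instead of 9

val unpacked = new Array[Double](n * n)
var r = 0
for (i <- 0 until n; j <- 0 to i) {
  unpacked(i * n + j) = packed(r)
  unpacked(j * n + i) = packed(r)
  r += 1
}
assert(unpacked.sameElements(full))
```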
```diff
+
+  /**
+   * Update the weight, mean and covariance of gaussian distribution.
+   *
+   * @param mean The mean of the gaussian distribution.
+   * @param cov The covariance matrix of the gaussian distribution. Note we only
+   *            save the upper triangular part as a dense vector (column major).
+   * @param weight The weight of the gaussian distribution.
+   * @param sumWeights The sum of weights of all clusters.
+   * @return The updated weight, mean and covariance.
+   */
+  private[clustering] def updateWeightsAndGaussians(
+      mean: DenseVector,
+      cov: DenseVector,
+      weight: Double,
+      sumWeights: Double): (Double, (DenseVector, DenseVector)) = {
+    BLAS.scal(1.0 / weight, mean)
+    BLAS.spr(-weight, mean, cov)
+    BLAS.scal(1.0 / weight, cov)
+    val newWeight = weight / sumWeights
+    val newGaussian = (mean, cov)
+    (newWeight, newGaussian)
+  }
+}
```
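Reading the three BLAS calls in EM terms: `weight` is the responsibility mass W_k = Σ_i γ_ik accumulated for component k, `mean` arrives as Σ_i γ_ik x_i, and `cov` arrives as the packed upper triangle of Σ_i γ_ik x_i x_iᵀ. The calls then realize the standard M-step:

```latex
\mu_k = \frac{1}{W_k}\sum_i \gamma_{ik}\, x_i, \qquad
\Sigma_k = \frac{1}{W_k}\Big(\sum_i \gamma_{ik}\, x_i x_i^\top - W_k\, \mu_k \mu_k^\top\Big), \qquad
w_k = \frac{W_k}{\sum_j W_j}
```

`BLAS.scal(1.0 / weight, mean)` turns the weighted sum into μ_k, `BLAS.spr(-weight, mean, cov)` subtracts W_k μ_k μ_kᵀ directly in packed upper-triangular form, and the final `scal` divides by W_k.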
```diff
+
+/**
+ * ExpectationAggregator computes the partial expectation results.
+ *
+ * @param numFeatures The number of features.
+ * @param bcWeights The broadcast weights for each Gaussian distribution in the mixture.
+ * @param bcGaussians The broadcast array of Multivariate Gaussian (Normal) Distribution
+ *                    in the mixture. Note only upper triangular part of the covariance
+ *                    matrix of each distribution is stored as dense vector (column major)
+ *                    in order to reduce shuffled data size.
+ */
+private class ExpectationAggregator(
+    numFeatures: Int,
+    bcWeights: Broadcast[Array[Double]],
+    bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable {
+
+  private val k: Int = bcWeights.value.length
+  private var totalCnt: Long = 0L
+  private var newLogLikelihood: Double = 0.0
+  private val newWeights: Array[Double] = new Array[Double](k)
+  private val newMeans: Array[DenseVector] = Array.fill(k)(
+    new DenseVector(Array.fill[Double](numFeatures)(0.0)))
+  private val newCovs: Array[DenseVector] = Array.fill(k)(
+    new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0)))
+
+  @transient private lazy val oldGaussians = {
+    bcGaussians.value.map { case (mean, covVec) =>
+      val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+      new MultivariateGaussian(mean, cov)
+    }
+  }
+
+  def count: Long = totalCnt
+
+  def logLikelihood: Double = newLogLikelihood
+
+  def weights: Array[Double] = newWeights
+
+  def means: Array[DenseVector] = newMeans
+
+  def covs: Array[DenseVector] = newCovs
+
+  /**
+   * Add a new training instance to this ExpectationAggregator, update the weights,
+   * means and covariances for each distributions, and update the log likelihood.
+   *
+   * @param instance The instance of data point to be added.
+   * @return This ExpectationAggregator object.
+   */
+  def add(instance: Vector): this.type = {
+    val localWeights = bcWeights.value
+    val localOldGaussians = oldGaussians
+
+    val prob = new Array[Double](k)
+    var probSum = 0.0
+    var i = 0
+    while (i < k) {
+      val p = EPSILON + localWeights(i) * localOldGaussians(i).pdf(instance)
+      prob(i) = p
+      probSum += p
+      i += 1
+    }
+
+    newLogLikelihood += math.log(probSum)
+    val localNewWeights = newWeights
+    val localNewMeans = newMeans
+    val localNewCovs = newCovs
+    i = 0
+    while (i < k) {
+      prob(i) /= probSum
+      localNewWeights(i) += prob(i)
+      BLAS.axpy(prob(i), instance, localNewMeans(i))
+      BLAS.spr(prob(i), instance, localNewCovs(i))
+      i += 1
+    }
+
+    totalCnt += 1
+    this
+  }
```
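In EM terms, `add` folds one point x into the running sufficient statistics. With the EPSILON guard keeping probSum strictly positive even when every density underflows, the computed responsibilities and log-likelihood contribution are:

```latex
\gamma_{k}(x) = \frac{w_k\,\mathcal{N}(x \mid \mu_k, \Sigma_k) + \epsilon}
                     {\sum_{j=1}^{K}\big(w_j\,\mathcal{N}(x \mid \mu_j, \Sigma_j) + \epsilon\big)}, \qquad
\Delta \log L = \log \sum_{j=1}^{K}\big(w_j\,\mathcal{N}(x \mid \mu_j, \Sigma_j) + \epsilon\big)
```

Note that `BLAS.spr(prob(i), instance, localNewCovs(i))` is a packed symmetric rank-1 update, cov += γ_k(x) · x xᵀ, touching only the n(n + 1)/2 stored entries; this is where the shuffle-size saving is realized.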
```diff
+
+  /**
+   * Merge another ExpectationAggregator, update the weights, means and covariances
+   * for each distributions, and update the log likelihood.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other ExpectationAggregator to be merged.
+   * @return This ExpectationAggregator object.
+   */
+  def merge(other: ExpectationAggregator): this.type = {
+    if (other.count != 0) {
+      totalCnt += other.totalCnt
+
+      val localThisNewWeights = this.newWeights
+      val localOtherNewWeights = other.newWeights
+      val localThisNewMeans = this.newMeans
+      val localOtherNewMeans = other.newMeans
+      val localThisNewCovs = this.newCovs
+      val localOtherNewCovs = other.newCovs
+      var i = 0
+      while (i < k) {
+        localThisNewWeights(i) += localOtherNewWeights(i)
+        BLAS.axpy(1.0, localOtherNewMeans(i), localThisNewMeans(i))
+        BLAS.axpy(1.0, localOtherNewCovs(i), localThisNewCovs(i))
+        i += 1
+      }
+      newLogLikelihood += other.newLogLikelihood
+    }
+    this
+  }
 }
 
 /**
```
