Commit 21b3d2a

[SPARK-11530][MLLIB] Return eigenvalues with PCA model
Add `computePrincipalComponentsAndExplainedVariance` to also compute PCA's explained variance. CC mengxr

Author: Sean Owen <[email protected]>

Closes #9736 from srowen/SPARK-11530.
1 parent e29704f commit 21b3d2a

7 files changed: +67 -25 lines

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

Lines changed: 11 additions & 9 deletions
@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
     val pca = new feature.PCA(k = $(k))
     val pcaModel = pca.fit(input)
-    copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
+    copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this))
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -105,7 +105,8 @@ object PCA extends DefaultParamsReadable[PCA] {
 @Experimental
 class PCAModel private[ml] (
     override val uid: String,
-    val pc: DenseMatrix)
+    val pc: DenseMatrix,
+    val explainedVariance: DenseVector)
   extends Model[PCAModel] with PCAParams with MLWritable {

   import PCAModel._
@@ -123,7 +124,7 @@ class PCAModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val pcaModel = new feature.PCAModel($(k), pc)
+    val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
     val pcaOp = udf { pcaModel.transform _ }
     dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
   }
@@ -139,7 +140,7 @@ class PCAModel private[ml] (
   }

   override def copy(extra: ParamMap): PCAModel = {
-    val copied = new PCAModel(uid, pc)
+    val copied = new PCAModel(uid, pc, explainedVariance)
     copyValues(copied, extra).setParent(parent)
   }

@@ -152,11 +153,11 @@ object PCAModel extends MLReadable[PCAModel] {

   private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {

-    private case class Data(pc: DenseMatrix)
+    private case class Data(pc: DenseMatrix, explainedVariance: DenseVector)

     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.pc)
+      val data = Data(instance.pc, instance.explainedVariance)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -169,10 +170,11 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
-        .select("pc")
+      val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
+        sqlContext.read.parquet(dataPath)
+          .select("pc", "explainedVariance")
           .head()
-      val model = new PCAModel(metadata.uid, pc)
+      val model = new PCAModel(metadata.uid, pc, explainedVariance)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
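
For context, a minimal usage sketch (not part of this diff) of how the new field surfaces through the DataFrame-based estimator. It assumes a Spark 1.6-style SQLContext named sqlContext is in scope; the feature vectors are purely illustrative.

import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Illustrative three-dimensional feature vectors, not taken from the patch.
val data = Seq(
  Vectors.dense(2.0, 0.0, 3.0),
  Vectors.dense(4.0, 1.0, 5.0),
  Vectors.dense(6.0, 7.0, 8.0)
).map(Tuple1.apply)
val df = sqlContext.createDataFrame(data).toDF("features")

val model = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(2)
  .fit(df)

// New in this change: the proportion of variance explained by each of the k components.
println(model.explainedVariance)
model.transform(df).select("pcaFeatures").show()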

mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala

Lines changed: 12 additions & 4 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.mllib.feature

-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.distributed.RowMatrix
@@ -43,7 +43,8 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
       s"source vector size is ${sources.first().size} must be greater than k=$k")

     val mat = new RowMatrix(sources)
-    val pc = mat.computePrincipalComponents(k) match {
+    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
+    val densePC = pc match {
       case dm: DenseMatrix =>
         dm
       case sm: SparseMatrix =>
@@ -58,7 +59,13 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
           s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")

     }
-    new PCAModel(k, pc)
+    val denseExplainedVariance = explainedVariance match {
+      case dv: DenseVector =>
+        dv
+      case sv: SparseVector =>
+        sv.toDense
+    }
+    new PCAModel(k, densePC, denseExplainedVariance)
   }

   /**
@@ -77,7 +84,8 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
 @Since("1.4.0")
 class PCAModel private[spark] (
     @Since("1.4.0") val k: Int,
-    @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer {
+    @Since("1.4.0") val pc: DenseMatrix,
+    @Since("1.6.0") val explainedVariance: DenseVector) extends VectorTransformer {
   /**
    * Transform a vector by computed Principal Components.
    *
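
Likewise, a small sketch (again, not part of the diff) of the RDD-based API touched above; the SparkContext `sc` and the data are assumed for illustration.

import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Illustrative data, not from the patch.
val dataRDD = sc.parallelize(Seq(
  Vectors.dense(2.0, 0.0, 3.0),
  Vectors.dense(4.0, 1.0, 5.0),
  Vectors.dense(6.0, 7.0, 8.0)
))

val pcaModel = new PCA(2).fit(dataRDD)
// explainedVariance is a DenseVector of length k; each entry is the fraction of
// total variance captured by the corresponding principal component.
println(pcaModel.explainedVariance)
val projected = pcaModel.transform(dataRDD)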

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 26 additions & 7 deletions
@@ -368,7 +368,8 @@ class RowMatrix @Since("1.0.0") (
   }

   /**
-   * Computes the top k principal components.
+   * Computes the top k principal components and a vector of proportions of
+   * variance explained by each principal component.
    * Rows correspond to observations and columns correspond to variables.
    * The principal components are stored a local matrix of size n-by-k.
    * Each column corresponds for one principal component,
@@ -379,24 +380,42 @@ class RowMatrix @Since("1.0.0") (
    * Note that this cannot be computed on matrices with more than 65535 columns.
    *
    * @param k number of top principal components.
-   * @return a matrix of size n-by-k, whose columns are principal components
+   * @return a matrix of size n-by-k, whose columns are principal components, and
+   *         a vector of values which indicate how much variance each principal component
+   *         explains
    */
-  @Since("1.0.0")
-  def computePrincipalComponents(k: Int): Matrix = {
+  @Since("1.6.0")
+  def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = {
     val n = numCols().toInt
     require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")

     val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]]

-    val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov)
+    val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov)
+
+    val eigenSum = s.data.sum
+    val explainedVariance = s.data.map(_ / eigenSum)

     if (k == n) {
-      Matrices.dense(n, k, u.data)
+      (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance))
     } else {
-      Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k))
+      (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)),
+        Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k)))
     }
   }

+  /**
+   * Computes the top k principal components only.
+   *
+   * @param k number of top principal components.
+   * @return a matrix of size n-by-k, whose columns are principal components
+   * @see computePrincipalComponentsAndExplainedVariance
+   */
+  @Since("1.0.0")
+  def computePrincipalComponents(k: Int): Matrix = {
+    computePrincipalComponentsAndExplainedVariance(k)._1
+  }
+
   /**
    * Computes column-wise summary statistics.
    */
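
The variance fractions come from the eigenvalues of the covariance matrix: each singular value s(i) divided by the sum of all of them, as in the hunk above. A sketch (not part of the diff) of calling the new RowMatrix method directly, with illustrative data and an assumed SparkContext `sc`:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Illustrative rows, not from the patch.
val rows = sc.parallelize(Seq(
  Vectors.dense(2.0, 0.0, 3.0),
  Vectors.dense(4.0, 1.0, 5.0),
  Vectors.dense(6.0, 7.0, 8.0)
))
val mat = new RowMatrix(rows)

// New method: principal components together with the explained-variance vector.
// The fractions are normalized over all n components, so the k returned values
// need not sum to 1 unless k == n.
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(2)
val projected = mat.multiply(pc)  // n-by-k projection of the original rows

// The existing entry point is kept and now simply drops the variance vector.
val pcOnly = mat.computePrincipalComponents(2)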

mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala

Lines changed: 4 additions & 3 deletions
@@ -24,15 +24,15 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
 import org.apache.spark.sql.Row

 class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   test("params") {
     ParamsSuite.checkParams(new PCA)
     val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
-    val model = new PCAModel("pca", mat)
+    val explainedVariance = Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector]
+    val model = new PCAModel("pca", mat, explainedVariance)
     ParamsSuite.checkParams(model)
   }

@@ -76,7 +76,8 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   test("PCAModel read/write") {
     val instance = new PCAModel("myPCAModel",
-      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix],
+      Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector])
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.pc === instance.pc)
   }

mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala

Lines changed: 2 additions & 1 deletion
@@ -37,11 +37,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
     val pca = new PCA(k).fit(dataRDD)

     val mat = new RowMatrix(dataRDD)
-    val pc = mat.computePrincipalComponents(k)
+    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)

     val pca_transform = pca.transform(dataRDD).collect()
     val mat_multiply = mat.multiply(pc).rows.collect()

     assert(pca_transform.toSet === mat_multiply.toSet)
+    assert(pca.explainedVariance === explainedVariance)
   }
 }

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala

Lines changed: 9 additions & 1 deletion
@@ -17,6 +17,8 @@

 package org.apache.spark.mllib.linalg.distributed

+import java.util.Arrays
+
 import scala.util.Random

 import breeze.numerics.abs
@@ -49,6 +51,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     (0.0, 1.0, 0.0),
     (math.sqrt(2.0) / 2.0, 0.0, math.sqrt(2.0) / 2.0),
     (math.sqrt(2.0) / 2.0, 0.0, - math.sqrt(2.0) / 2.0))
+  val explainedVariance = BDV(4.0 / 7.0, 3.0 / 7.0, 0.0)

   var denseMat: RowMatrix = _
   var sparseMat: RowMatrix = _
@@ -201,10 +204,15 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {

   test("pca") {
     for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) {
-      val pc = denseMat.computePrincipalComponents(k)
+      val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
       assert(pc.numRows === n)
       assert(pc.numCols === k)
       assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k)
+      assert(
+        closeToZero(BDV(expVariance.toArray) -
+        BDV(Arrays.copyOfRange(explainedVariance.data, 0, k))))
+      // Check that this method returns the same answer
+      assert(pc === mat.computePrincipalComponents(k))
     }
   }

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ object MimaExcludes {
         // MiMa does not deal properly with sealed traits
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol")
+      ) ++ Seq(
+        // SPARK-11530
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this")
       ) ++ Seq(
         // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
         // This class is marked as `private` but MiMa still seems to be confused by the change.
