Skip to content

Commit 4f647ab

Browse files
MechCoder and Nick Pentreath
authored and committed
[SPARK-6227][MLLIB][PYSPARK] Implement PySpark wrappers for SVD and PCA (v2)
Add PCA and SVD to PySpark's wrappers for `RowMatrix` and `IndexedRowMatrix` (SVD only). Based on #7963, updated. ## How was this patch tested? New doc tests and unit tests. Ran all examples locally. Author: MechCoder <[email protected]> Author: Nick Pentreath <[email protected]> Closes #17621 from MLnick/SPARK-6227-pyspark-svd-pca. (cherry picked from commit db2fb84) Signed-off-by: Nick Pentreath <[email protected]>
1 parent c80242a commit 4f647ab

File tree

9 files changed

+408
-46
lines changed

9 files changed

+408
-46
lines changed

docs/mllib-dimensionality-reduction.md

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/
7676

7777
The same code applies to `IndexedRowMatrix` if `U` is defined as an
7878
`IndexedRowMatrix`.
79+
</div>
80+
<div data-lang="python" markdown="1">
81+
Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API.
7982

80-
In order to run the above application, follow the instructions
81-
provided in the [Self-Contained
82-
Applications](quick-start.html#self-contained-applications) section of the Spark
83-
quick-start guide. Be sure to also include *spark-mllib* to your build file as
84-
a dependency.
83+
{% include_example python/mllib/svd_example.py %}
8584

85+
The same code applies to `IndexedRowMatrix` if `U` is defined as an
86+
`IndexedRowMatrix`.
8687
</div>
8788
</div>
8889

@@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat
118119

119120
The following code demonstrates how to compute principal components on a `RowMatrix`
120121
and use them to project the vectors into a low-dimensional space.
121-
The number of columns should be small, e.g, less than 1000.
122122

123123
Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API.
124124

125125
{% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %}
126126

127127
</div>
128-
</div>
129128

130-
In order to run the above application, follow the instructions
131-
provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
132-
section of the Spark
133-
quick-start guide. Be sure to also include *spark-mllib* to your build file as
134-
a dependency.
129+
<div data-lang="python" markdown="1">
130+
131+
The following code demonstrates how to compute principal components on a `RowMatrix`
132+
and use them to project the vectors into a low-dimensional space.
133+
134+
Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API.
135+
136+
{% include_example python/mllib/pca_rowmatrix_example.py %}
137+
138+
</div>
139+
</div>

examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
package org.apache.spark.examples.mllib;
1919

2020
// $example on$
21-
import java.util.LinkedList;
21+
import java.util.Arrays;
22+
import java.util.List;
2223
// $example off$
2324

2425
import org.apache.spark.SparkConf;
@@ -39,28 +40,32 @@ public class JavaPCAExample {
3940
public static void main(String[] args) {
4041
SparkConf conf = new SparkConf().setAppName("PCA Example");
4142
SparkContext sc = new SparkContext(conf);
43+
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
4244

4345
// $example on$
44-
double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}};
45-
LinkedList<Vector> rowsList = new LinkedList<>();
46-
for (int i = 0; i < array.length; i++) {
47-
Vector currentRow = Vectors.dense(array[i]);
48-
rowsList.add(currentRow);
49-
}
50-
JavaRDD<Vector> rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList);
46+
List<Vector> data = Arrays.asList(
47+
Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}),
48+
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
49+
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
50+
);
51+
52+
JavaRDD<Vector> rows = jsc.parallelize(data);
5153

5254
// Create a RowMatrix from JavaRDD<Vector>.
5355
RowMatrix mat = new RowMatrix(rows.rdd());
5456

55-
// Compute the top 3 principal components.
56-
Matrix pc = mat.computePrincipalComponents(3);
57+
// Compute the top 4 principal components.
58+
// Principal components are stored in a local dense matrix.
59+
Matrix pc = mat.computePrincipalComponents(4);
60+
61+
// Project the rows to the linear space spanned by the top 4 principal components.
5762
RowMatrix projected = mat.multiply(pc);
5863
// $example off$
5964
Vector[] collectPartitions = (Vector[])projected.rows().collect();
6065
System.out.println("Projected vector of principal component:");
6166
for (Vector vector : collectPartitions) {
6267
System.out.println("\t" + vector);
6368
}
64-
sc.stop();
69+
jsc.stop();
6570
}
6671
}

examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
package org.apache.spark.examples.mllib;
1919

2020
// $example on$
21-
import java.util.LinkedList;
21+
import java.util.Arrays;
22+
import java.util.List;
2223
// $example off$
2324

2425
import org.apache.spark.SparkConf;
@@ -43,22 +44,22 @@ public static void main(String[] args) {
4344
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
4445

4546
// $example on$
46-
double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}};
47-
LinkedList<Vector> rowsList = new LinkedList<>();
48-
for (int i = 0; i < array.length; i++) {
49-
Vector currentRow = Vectors.dense(array[i]);
50-
rowsList.add(currentRow);
51-
}
52-
JavaRDD<Vector> rows = jsc.parallelize(rowsList);
47+
List<Vector> data = Arrays.asList(
48+
Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}),
49+
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
50+
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
51+
);
52+
53+
JavaRDD<Vector> rows = jsc.parallelize(data);
5354

5455
// Create a RowMatrix from JavaRDD<Vector>.
5556
RowMatrix mat = new RowMatrix(rows.rdd());
5657

57-
// Compute the top 3 singular values and corresponding singular vectors.
58-
SingularValueDecomposition<RowMatrix, Matrix> svd = mat.computeSVD(3, true, 1.0E-9d);
59-
RowMatrix U = svd.U();
60-
Vector s = svd.s();
61-
Matrix V = svd.V();
58+
// Compute the top 5 singular values and corresponding singular vectors.
59+
SingularValueDecomposition<RowMatrix, Matrix> svd = mat.computeSVD(5, true, 1.0E-9d);
60+
RowMatrix U = svd.U(); // The U factor is a RowMatrix.
61+
Vector s = svd.s(); // The singular values are stored in a local dense vector.
62+
Matrix V = svd.V(); // The V factor is a local dense matrix.
6263
// $example off$
6364
Vector[] collectPartitions = (Vector[]) U.rows().collect();
6465
System.out.println("U factor is:");
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from pyspark import SparkContext
19+
# $example on$
20+
from pyspark.mllib.linalg import Vectors
21+
from pyspark.mllib.linalg.distributed import RowMatrix
22+
# $example off$
23+
24+
if __name__ == "__main__":
25+
sc = SparkContext(appName="PythonPCAOnRowMatrixExample")
26+
27+
# $example on$
28+
rows = sc.parallelize([
29+
Vectors.sparse(5, {1: 1.0, 3: 7.0}),
30+
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
31+
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
32+
])
33+
34+
mat = RowMatrix(rows)
35+
# Compute the top 4 principal components.
36+
# Principal components are stored in a local dense matrix.
37+
pc = mat.computePrincipalComponents(4)
38+
39+
# Project the rows to the linear space spanned by the top 4 principal components.
40+
projected = mat.multiply(pc)
41+
# $example off$
42+
collected = projected.rows.collect()
43+
print("Projected Row Matrix of principal component:")
44+
for vector in collected:
45+
print(vector)
46+
sc.stop()
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from pyspark import SparkContext
19+
# $example on$
20+
from pyspark.mllib.linalg import Vectors
21+
from pyspark.mllib.linalg.distributed import RowMatrix
22+
# $example off$
23+
24+
if __name__ == "__main__":
25+
sc = SparkContext(appName="PythonSVDExample")
26+
27+
# $example on$
28+
rows = sc.parallelize([
29+
Vectors.sparse(5, {1: 1.0, 3: 7.0}),
30+
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
31+
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
32+
])
33+
34+
mat = RowMatrix(rows)
35+
36+
# Compute the top 5 singular values and corresponding singular vectors.
37+
svd = mat.computeSVD(5, computeU=True)
38+
U = svd.U # The U factor is a RowMatrix.
39+
s = svd.s # The singular values are stored in a local dense vector.
40+
V = svd.V # The V factor is a local dense matrix.
41+
# $example off$
42+
collected = U.rows.collect()
43+
print("U factor is:")
44+
for vector in collected:
45+
print(vector)
46+
print("Singular values are: %s" % s)
47+
print("V factor is:\n%s" % V)
48+
sc.stop()

examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ object PCAOnRowMatrixExample {
3939
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
4040
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
4141

42-
val dataRDD = sc.parallelize(data, 2)
42+
val rows = sc.parallelize(data)
4343

44-
val mat: RowMatrix = new RowMatrix(dataRDD)
44+
val mat: RowMatrix = new RowMatrix(rows)
4545

4646
// Compute the top 4 principal components.
4747
// Principal components are stored in a local dense matrix.

examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors
2828
import org.apache.spark.mllib.linalg.distributed.RowMatrix
2929
// $example off$
3030

31+
/**
32+
* Example for SingularValueDecomposition.
33+
*/
3134
object SVDExample {
3235

3336
def main(args: Array[String]): Unit = {
@@ -41,15 +44,15 @@ object SVDExample {
4144
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
4245
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
4346

44-
val dataRDD = sc.parallelize(data, 2)
47+
val rows = sc.parallelize(data)
4548

46-
val mat: RowMatrix = new RowMatrix(dataRDD)
49+
val mat: RowMatrix = new RowMatrix(rows)
4750

4851
// Compute the top 5 singular values and corresponding singular vectors.
4952
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true)
5053
val U: RowMatrix = svd.U // The U factor is a RowMatrix.
51-
val s: Vector = svd.s // The singular values are stored in a local dense vector.
52-
val V: Matrix = svd.V // The V factor is a local dense matrix.
54+
val s: Vector = svd.s // The singular values are stored in a local dense vector.
55+
val V: Matrix = svd.V // The V factor is a local dense matrix.
5356
// $example off$
5457
val collect = U.rows.collect()
5558
println("U factor is:")

0 commit comments

Comments
 (0)