Skip to content

Commit ac96fb2

Browse files
committed
Compute approximation error, add command line.
1 parent 4533579 commit ac96fb2

File tree

1 file changed

+42
-22
lines changed

1 file changed

+42
-22
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,66 @@
1717

1818
package org.apache.spark.examples.mllib
1919

20+
import org.apache.spark.mllib.linalg.Vectors
2021
import org.apache.spark.{SparkConf, SparkContext}
2122
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix, RowMatrix}
2223

2324
/**
2425
* Compute the similar columns of a matrix, using cosine similarity.
26+
*
27+
* The input matrix must be stored in row-oriented dense format, one line per row with its entries
28+
* separated by space. For example,
29+
* {{{
30+
* 0.5 1.0
31+
* 2.0 3.0
32+
* 4.0 5.0
33+
* }}}
34+
* represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
35+
*
36+
* Example invocation:
37+
*
38+
* bin/run-example org.apache.spark.examples.mllib.CosineSimilarity \
39+
* data/mllib/sample_svm_data.txt 0.1
2540
*/
2641
object CosineSimilarity {
2742
def main(args: Array[String]) {
43+
if (args.length != 2) {
44+
System.err.println("Usage: CosineSimilarity <input> <threshold>")
45+
System.exit(1)
46+
}
47+
2848
val conf = new SparkConf().setAppName("CosineSimilarity")
2949
val sc = new SparkContext(conf)
3050

31-
// Number of rows
32-
val M = 1000
33-
// Number of columns
34-
val U = 1000
35-
// Number of nonzeros per row
36-
val NNZ = 10
37-
// Number of partitions for data
38-
val NUMCHUNKS = 4
39-
40-
// Create data
41-
val R = sc.parallelize(0 until M, NUMCHUNKS).flatMap{i =>
42-
val inds = new scala.collection.mutable.TreeSet[Int]()
43-
while (inds.size < NNZ) {
44-
inds += scala.util.Random.nextInt(U)
45-
}
46-
inds.toArray.map(j => MatrixEntry(i, j, scala.math.random))
51+
// Load and parse the data file.
52+
val rows = sc.textFile(args(0)).map { line =>
53+
val values = line.split(' ').map(_.toDouble)
54+
Vectors.dense(values)
4755
}
56+
val mat = new RowMatrix(rows)
4857

49-
val mat = new CoordinateMatrix(R, M, U).toRowMatrix()
58+
val threshold = args(1).toDouble
5059

5160
// Compute similar columns perfectly, with brute force.
52-
val simsPerfect = mat.columnSimilarities()
61+
val simsPerfect = mat.columnSimilarities().entries.collect
62+
63+
// Compute similar columns with estimation focusing on pairs more similar than threshold
64+
val simsEstimate = mat.columnSimilarities(threshold).entries.collect
5365

54-
println("Pairwise similarities are: " + simsPerfect.entries.collect.mkString(", "))
66+
val n = mat.numCols().toInt
67+
val real = Array.ofDim[Double](n, n)
68+
val est = Array.ofDim[Double](n, n)
69+
for (entry <- simsPerfect) {
70+
real(entry.i.toInt)(entry.j.toInt) = entry.value
71+
}
72+
for (entry <- simsEstimate) {
73+
est(entry.i.toInt)(entry.j.toInt) = entry.value
74+
}
5575

56-
// Compute similar columns with estimation focusing on pairs more similar than 0.8
57-
val simsEstimate = mat.columnSimilarities(0.8)
76+
val errors = Array.tabulate[Double](n, n)((i, j) => math.abs(real(i)(j) - est(i)(j)))
77+
val avgErr = errors.flatten.sum / (n * (n - 1) / 2)
5878

59-
println("Estimated pairwise similarities are: " + simsEstimate.entries.collect.mkString(", "))
79+
println(s"Average error in estimate is: $avgErr")
6080

6181
sc.stop()
6282
}

0 commit comments

Comments
 (0)