|
17 | 17 |
|
18 | 18 | package org.apache.spark.examples.mllib |
19 | 19 |
|
| 20 | +import org.apache.spark.mllib.linalg.Vectors |
20 | 21 | import org.apache.spark.{SparkConf, SparkContext} |
21 | 22 | import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix, RowMatrix} |
22 | 23 |
|
23 | 24 | /** |
24 | 25 | * Compute the similar columns of a matrix, using cosine similarity. |
| 26 | + * |
| 27 | + * The input matrix must be stored in row-oriented dense format, one line per row with its entries |
| 28 | + * separated by a single space. For example, |
| 29 | + * {{{ |
| 30 | + * 0.5 1.0 |
| 31 | + * 2.0 3.0 |
| 32 | + * 4.0 5.0 |
| 33 | + * }}} |
| 34 | + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). |
| 35 | + * |
| 36 | + * Example invocation: |
| 37 | + * |
| 38 | + * bin/run-example org.apache.spark.examples.mllib.CosineSimilarity \ |
| 39 | + * data/mllib/sample_svm_data.txt 0.1 |
25 | 40 | */ |
26 | 41 | object CosineSimilarity { |
27 | 42 | def main(args: Array[String]) { |
| 43 | + if (args.length != 2) { |
| 44 | + System.err.println("Usage: CosineSimilarity <input> <threshold>") |
| 45 | + System.exit(1) |
| 46 | + } |
| 47 | + |
28 | 48 | val conf = new SparkConf().setAppName("CosineSimilarity") |
29 | 49 | val sc = new SparkContext(conf) |
30 | 50 |
|
31 | | - // Number of rows |
32 | | - val M = 1000 |
33 | | - // Number of columns |
34 | | - val U = 1000 |
35 | | - // Number of nonzeros per row |
36 | | - val NNZ = 10 |
37 | | - // Number of partitions for data |
38 | | - val NUMCHUNKS = 4 |
39 | | - |
40 | | - // Create data |
41 | | - val R = sc.parallelize(0 until M, NUMCHUNKS).flatMap{i => |
42 | | - val inds = new scala.collection.mutable.TreeSet[Int]() |
43 | | - while (inds.size < NNZ) { |
44 | | - inds += scala.util.Random.nextInt(U) |
45 | | - } |
46 | | - inds.toArray.map(j => MatrixEntry(i, j, scala.math.random)) |
| 51 | + // Load and parse the data file. |
| 52 | + val rows = sc.textFile(args(0)).map { line => |
| 53 | + val values = line.split(' ').map(_.toDouble) |
| 54 | + Vectors.dense(values) |
47 | 55 | } |
| 56 | + val mat = new RowMatrix(rows) |
48 | 57 |
|
49 | | - val mat = new CoordinateMatrix(R, M, U).toRowMatrix() |
| 58 | + val threshold = args(1).toDouble |
50 | 59 |
|
51 | 60 | // Compute the exact column similarities by brute force. |
52 | | - val simsPerfect = mat.columnSimilarities() |
| 61 | + val simsPerfect = mat.columnSimilarities().entries.collect |
| 62 | + |
| 63 | + // Compute similar columns by estimation, focusing on pairs more similar than the threshold |
| 64 | + val simsEstimate = mat.columnSimilarities(threshold).entries.collect |
53 | 65 |
|
54 | | - println("Pairwise similarities are: " + simsPerfect.entries.collect.mkString(", ")) |
| 66 | + val n = mat.numCols().toInt |
| 67 | + val real = Array.ofDim[Double](n, n) |
| 68 | + val est = Array.ofDim[Double](n, n) |
| 69 | + for (entry <- simsPerfect) { |
| 70 | + real(entry.i.toInt)(entry.j.toInt) = entry.value |
| 71 | + } |
| 72 | + for (entry <- simsEstimate) { |
| 73 | + est(entry.i.toInt)(entry.j.toInt) = entry.value |
| 74 | + } |
55 | 75 |
|
56 | | - // Compute similar columns with estimation focusing on pairs more similar than 0.8 |
57 | | - val simsEstimate = mat.columnSimilarities(0.8) |
| 76 | + val errors = Array.tabulate[Double](n, n)((i, j) => math.abs(real(i)(j) - est(i)(j))) |
| 77 | + val avgErr = errors.flatten.sum / (n * (n - 1) / 2) |
58 | 78 |
|
59 | | - println("Estimated pairwise similarities are: " + simsEstimate.entries.collect.mkString(", ")) |
| 79 | + println(s"Average error in estimate is: $avgErr") |
60 | 80 |
|
61 | 81 | sc.stop() |
62 | 82 | } |
|
0 commit comments